import argparse
import collections.abc
import functools
import os

import fancy_gym
import gymnasium as gym
import wandb
import yaml
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
from wandb.integration.sb3 import WandbCallback

# Placeholder for the action-distribution class used when a config's
# `pca.enable` is set; run_single refuses to enable PCA while this is None.
PCA = None


def load_config(filename, name):
    """Load the YAML document called *name* from *filename*, resolving imports.

    *filename* may hold several YAML documents; the first one whose ``name``
    key equals *name* is selected. If it has an ``import`` key, that value is a
    comma-separated list of ``[rel_path][:doc_name]`` entries:

    * ``rel_path`` is resolved relative to the directory of *filename*; an
      empty path refers to *filename* itself.
    * ``doc_name`` defaults to ``'DEFAULT'``.

    Imported documents are loaded recursively and deep-merged underneath the
    selected document: the selected document's own keys win, and among the
    imports, later entries override earlier ones.

    Raises:
        ValueError: if an import entry contains more than one ``:``, or if no
            document named *name* exists in *filename*.
    """
    # Materialize the documents so the file is closed before recursing.
    with open(filename, 'r', encoding='utf-8') as f:
        docs = list(yaml.safe_load_all(f))

    for doc in docs:
        # Empty YAML documents load as None; skip anything that isn't a mapping.
        if not isinstance(doc, dict) or 'name' not in doc or doc['name'] != name:
            continue
        if 'import' in doc:
            # Process imports last-to-first so that, after merging, later
            # imports take precedence over earlier ones.
            for entry in reversed(doc.pop('import').split(',')):
                rel_path, *opt = entry.split(':')
                rel_path = rel_path.strip()
                if len(opt) > 1:
                    raise ValueError(
                        f"Invalid import entry {entry!r} in {filename}: "
                        f"expected 'path[:name]'"
                    )
                nested_name = opt[0].strip() if opt else 'DEFAULT'
                if rel_path:
                    nested_path = os.path.normpath(
                        os.path.join(os.path.dirname(filename), rel_path)
                    )
                else:
                    nested_path = filename
                child = load_config(nested_path, nested_name)
                # The current document is merged on top of the imported one.
                doc = deep_update(child, doc)
        return doc

    raise ValueError(f"No config named {name!r} found in {filename}")


def deep_update(d, u):
    """Recursively merge mapping *u* into mapping *d* in place and return *d*.

    Nested mappings are merged key by key; any other value from *u* replaces
    the corresponding value in *d*. If *d* holds a non-mapping value (e.g.
    None) where *u* has a mapping, it is replaced by a merged copy of *u*'s
    mapping.
    """
    for key, value in u.items():
        if isinstance(value, collections.abc.Mapping):
            existing = d.get(key)
            if not isinstance(existing, collections.abc.Mapping):
                existing = {}
            d[key] = deep_update(existing, value)
        else:
            d[key] = value
    return d


def run(filename, name):
    """Load config *name* from *filename* and train once or run a W&B sweep.

    If the config has a ``sweep`` section with a truthy ``enable`` key, the
    rest of that section is registered via ``wandb.sweep`` and a
    ``wandb.agent`` runs ``config['reps_per_agent']`` trainings; otherwise a
    single training is run.
    """
    config = load_config(filename, name)
    sweep_conf = config.get('sweep') or {}
    if sweep_conf.get('enable'):
        # Build the sweep spec without mutating the loaded config.
        sweep_spec = {k: v for k, v in sweep_conf.items() if k != 'enable'}
        sweep_id = wandb.sweep(sweep=sweep_spec, project=config['project'])
        # wandb.agent calls `function` with no arguments, so bind the config.
        # NOTE(review): sweep-sampled hyperparameters are exposed through
        # wandb.config, but run_single builds everything from its `config`
        # argument, so they are currently not applied -- confirm intent.
        wandb.agent(
            sweep_id,
            function=functools.partial(run_single, config),
            count=config['reps_per_agent'],
        )
    else:
        run_single(config)


def run_single(config):
    """Train one PPO model described by *config*, logging to Weights & Biases.

    Reads ``project`` and ``total_timesteps`` plus the optional ``env``,
    ``algo``, ``video`` and ``pca`` sections. Section dicts are copied before
    keys are consumed, so *config* is not modified and this function can be
    called repeatedly (e.g. by a sweep agent).

    Raises:
        ValueError: if ``algo.name`` is not ``'PPO'``.
        NotImplementedError: if ``pca.enable`` is set while ``PCA`` is None.
    """
    video_conf = dict(config.get('video') or {})
    env_conf = dict(config.get('env') or {})
    algo_conf = dict(config.get('algo') or {})
    pca_conf = dict(config.get('pca') or {})

    # Validate the algorithm section before starting a W&B run.
    algo_name = algo_conf.pop('name')
    if algo_name != 'PPO':
        raise ValueError(
            f"Unsupported algorithm {algo_name!r}; only 'PPO' is supported"
        )
    policy_name = algo_conf.pop('policy_name')
    use_pca = pca_conf.pop('enable', False)
    if use_pca and PCA is None:
        raise NotImplementedError(
            "pca.enable is set, but no PCA action distribution is available "
            "(module-level PCA is None)"
        )

    with wandb.init(
        project=config['project'],
        config=config,
        sync_tensorboard=True,
        monitor_gym=True,
        save_code=True,
    ) as wandb_run:
        env = DummyVecEnv([make_env_func(env_conf)])
        if video_conf.get('enable', False):
            frequency = video_conf['frequency']
            env = VecVideoRecorder(
                env,
                f"videos/{wandb_run.id}",
                record_video_trigger=lambda step: step % frequency == 0,
                video_length=video_conf['length'],
            )
        try:
            # Remaining algo keys are forwarded verbatim as PPO kwargs.
            model = PPO(policy_name, env, **algo_conf)
            if use_pca:
                model.policy.action_dist = PCA(
                    model.policy.action_space.shape, **pca_conf
                )
            model.learn(
                total_timesteps=config["total_timesteps"],
                callback=WandbCallback(),
            )
        finally:
            # Release the environments; VecVideoRecorder also closes its
            # video recorder here.
            env.close()


def make_env_func(env_conf):
    """Return a zero-argument factory building a Monitor-wrapped environment.

    *env_conf* must contain ``name`` (the environment id). The optional
    boolean ``legacy_fancy`` selects ``fancy_gym.make`` instead of
    ``gym.make``. All remaining keys are passed as keyword arguments to the
    make call. *env_conf* is never modified, so the factory may be called
    more than once.
    """
    def func():
        kwargs = dict(env_conf)
        env_id = kwargs.pop('name')
        legacy_fancy = kwargs.pop('legacy_fancy', False)
        if legacy_fancy:
            # TODO: Remove when no longer needed
            env = fancy_gym.make(env_id, **kwargs)
        else:
            env = gym.make(env_id, **kwargs)
        return Monitor(env)

    return func


def main(argv=None):
    """Command-line entry point: ``<script> CONFIG_FILE [NAME]``.

    *argv* defaults to ``sys.argv[1:]`` (argparse behaviour); NAME defaults
    to ``'DEFAULT'``, the same default document name used for imports.
    """
    parser = argparse.ArgumentParser(
        description='Train a PPO agent from a named YAML config document.'
    )
    parser.add_argument('filename', help='Path to the YAML config file.')
    parser.add_argument(
        'name',
        nargs='?',
        default='DEFAULT',
        help="Name of the config document to run (default: 'DEFAULT').",
    )
    args = parser.parse_args(argv)
    run(args.filename, args.name)


if __name__ == '__main__':
    main()