"""Experiment launcher.

Loads named experiment configs from multi-document YAML files and runs them
either locally (a single run or a wandb sweep agent) or as a Slurm array job.
"""
import collections.abc
import copy
import os
import pdb
import random
from functools import partial

import fancy_gym
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
import wandb
from wandb.integration.sb3 import WandbCallback
import gymnasium as gym
import yaml

# Debugging convenience alias: call d() anywhere to drop into pdb.
d = pdb.set_trace

try:
    import pyslurm
except ImportError:
    slurm_avaible = False
else:
    slurm_avaible = True

# Placeholder for a PCA action distribution; not implemented yet.
PCA = None

# TODO: Implement Testing
# TODO: Implement PCA
# TODO: Implement Slurm
# TODO: Implement Parallel


def load_config(filename, name):
    """Load and fully resolve experiment `name` from YAML file `filename`.

    Resolves `import` chains (see `_load_config`), expands `str.format`
    placeholders in every string value (the config itself is available as
    ``{config[...]}``), and finally drops the helper `vars` section.

    Returns:
        The resolved config dict.
    """
    config = _load_config(filename, name)
    deep_expand_vars(config, config=config)
    consume(config, 'vars', {})
    return config


def _load_config(filename, name):
    """Return the YAML document whose `name` key equals `name`.

    A document may carry an `import` entry: a comma-separated string (or a
    list) of references of the form ``<file>:<experiment>``. ``<file>`` alone
    means ``<file>:DEFAULT``; an empty file part (``:<experiment>``) or ``$``
    (shorthand for ``:DEFAULT``) refers to `filename` itself. Relative paths
    are resolved against the directory of `filename`. Imported documents act
    as defaults: the importing document overrides them, and later entries in
    the import list override earlier ones.

    Raises:
        Exception: on a malformed import entry, or if `filename` contains no
            document named `name`.
    """
    with open(filename, 'r') as f:
        for doc in yaml.safe_load_all(f):
            # Skip empty documents (e.g. a trailing '---') and non-mappings.
            if not isinstance(doc, dict) or 'name' not in doc or doc['name'] != name:
                continue
            if 'import' in doc:
                raw_imports = doc.pop('import')
                if isinstance(raw_imports, str):
                    raw_imports = raw_imports.split(',')
                # Merge in reverse so that earlier imports end up with the
                # lowest precedence: doc > last import > ... > first import.
                for imp in reversed(raw_imports):
                    imp = imp.strip()
                    if imp == '$':
                        imp = ':DEFAULT'
                    rel_path, *opt = imp.split(':')
                    if len(opt) == 0:
                        nested_name = 'DEFAULT'
                    elif len(opt) == 1:
                        nested_name = opt[0]
                    else:
                        raise Exception(
                            'Malformed import statement. Must be <file>:<experiment>, '
                            ':<experiment>, $ or <file> (meaning <file>:DEFAULT); '
                            f'got {imp!r}.'
                        )
                    if rel_path:
                        nested_path = os.path.normpath(
                            os.path.join(os.path.dirname(filename), rel_path))
                    else:
                        nested_path = filename
                    child = _load_config(nested_path, nested_name)
                    doc = deep_update(child, doc)
            return doc
    raise Exception(f'Unable to find experiment <{name}> in <{filename}>')


def deep_update_old(d, u):
    """Recursively merge mapping `u` into `d` in place (no dotted keys)."""
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            d[k] = deep_update_old(d.get(k, {}), v)
        else:
            d[k] = v
    return d


def deep_update(d, u):
    """Recursively merge mapping `u` into `d` in place and return `d`.

    String keys of `u` may be dotted paths (``'algo.learning_rate'``) that
    address nested entries of `d`; missing (or non-dict) intermediate
    entries are replaced by new dicts. Mapping values are merged
    recursively; all other values overwrite the existing entry.
    """
    for key, v in u.items():
        path = key.split('.') if isinstance(key, str) else [key]
        *parents, leaf = path
        head = d
        for k in parents:
            nxt = head.get(k)
            if not isinstance(nxt, collections.abc.MutableMapping):
                nxt = head[k] = {}
            head = nxt
        if isinstance(v, collections.abc.Mapping):
            existing = head.get(leaf)
            if not isinstance(existing, collections.abc.MutableMapping):
                existing = {}
            head[leaf] = deep_update(existing, v)
        else:
            head[leaf] = v
    return d


def expand_vars(string, **kwargs):
    """Return `string.format(**kwargs)` for strings; other values unchanged."""
    if isinstance(string, str):
        return string.format(**kwargs)
    return string


def _apply_value(v, f):
    """Apply `f` to every leaf of `v`, mutating dicts/lists in place; return the result."""
    if isinstance(v, dict):
        apply_nested(v, f)
        return v
    if isinstance(v, list):
        for i, e in enumerate(v):
            v[i] = _apply_value(e, f)
        return v
    return f(v)


def apply_nested(d, f):
    """Replace every non-container leaf in dict `d` (recursing into dicts and lists) with `f(leaf)`, in place."""
    for k, v in d.items():
        # Re-assigning an existing key does not change the dict's size, so
        # this is safe during iteration.
        d[k] = _apply_value(v, f)


def deep_expand_vars(dict, **kwargs):
    """Expand format placeholders in every string leaf of `dict` in place."""
    apply_nested(dict, lambda x: expand_vars(x, **kwargs))


def consume(conf, keys, default=None):
    """Remove and return the value at dotted path `keys` in `conf`.

    If `default` is None the key is required and a missing key raises
    KeyError; otherwise `default` is returned when the key is absent.
    A missing or null intermediate section is treated as empty.
    """
    first, _, rest = keys.partition('.')
    if rest:
        child = conf.get(first)
        if child is None:
            child = {}
        return consume(child, rest, default=default)
    if default is not None:
        val = conf.get(first, default)
    else:
        val = conf[first]
    conf.pop(first, None)
    return val


def run_local(filename, name, job_num=None):
    """Run experiment `name` from `filename` in this process.

    Starts a wandb sweep agent if `sweep.enable` is set, otherwise a single
    run. `job_num` is accepted for the Slurm array entry point but unused.
    """
    config = load_config(filename, name)
    # Launcher-only sections: drop them so the runners, which require the
    # config to be fully consumed, only see experiment settings. Slurm worker
    # jobs re-enter here with the same config, so 'slurm' must go too.
    consume(config, 'slurm', {})
    schedulerC = consume(config, 'scheduler', {})
    if consume(config, 'sweep.enable', False):
        sweepC = consume(config, 'sweep')
        project = consume(config, 'wandb.project')
        sweep_id = wandb.sweep(sweep=sweepC, project=project)
        runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {})
        wandb.agent(
            sweep_id,
            function=partial(run_from_sweep, config, runnerName, project, wandbC),
            count=schedulerC['reps_per_agent'],
        )
    else:
        consume(config, 'sweep', {})
        run_single(config)


def run_from_sweep(orig_config, runnerName, project, wandbC):
    """Execute one sweep trial: overlay `wandb.config` onto a copy of `orig_config` and run it."""
    runner = Runners[runnerName]
    with wandb.init(project=project, **wandbC) as run:
        # Each trial works on its own copy; the agent reuses orig_config.
        config = copy.deepcopy(orig_config)
        deep_update(config, wandb.config)
        runner(run, config)
        assert config == {}, ('Config was not completely consumed: ', config)


def run_slurm(filename, name):
    """Submit experiment `name` from `filename` as a Slurm array job via pyslurm.

    Each array task re-invokes this launcher (as ``main.py``) with
    ``-j $SLURM_ARRAY_TASK_ID``. Keys left in the `slurm` section after
    `name`, `sh_lines`, `venv` and `num_parallel_jobs` are consumed are
    forwarded to `pyslurm.JobSubmitDescription`.

    Raises:
        RuntimeError: if pyslurm is not installed.
    """
    if not slurm_avaible:
        raise RuntimeError('pyslurm does not seem to be installed on this system.')
    config = load_config(filename, name)
    slurmC = consume(config, 'slurm')
    s_name = consume(slurmC, 'name')
    # NOTE(review): assumes this launcher is reachable as main.py from the
    # job's working directory.
    python_script = 'main.py'
    sh_lines = list(consume(slurmC, 'sh_lines', []))
    if venv := consume(slurmC, 'venv', False):
        sh_lines.append(f'source activate {venv}')
    sh_lines.append(f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID')
    # Batch scripts need an interpreter line; bash is required for `source`.
    script = '#!/bin/bash\n' + ' && '.join(sh_lines)
    num_jobs = 1
    last_job_idx = num_jobs - 1
    num_parallel_jobs = min(consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
    array = f'0-{last_job_idx}%{num_parallel_jobs}'
    job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
    job_id = job.submit()
    print(f'[i] Job submitted to slurm with id {job_id}')


def run_single(config):
    """Run one experiment under a fresh wandb run; the runner must consume the whole config."""
    runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {})
    runner = Runners[runnerName]
    project = consume(wandbC, 'project')
    with wandb.init(project=project, config=config, **wandbC) as run:
        runner(run, config)
    assert config == {}, ('Config was not completely consumed: ', config)


def main():
    """Command-line entry point: run an experiment locally or submit it to Slurm."""
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("config_file", nargs='?', default=None)
    parser.add_argument("experiment", nargs='?', default='DEFAULT')
    parser.add_argument("-s", "--slurm", action="store_true")
    parser.add_argument("-w", "--worker", action="store_true")
    parser.add_argument("-j", "--job_num", default=None)
    args = parser.parse_args()

    if args.worker:
        raise NotImplementedError('Not yet implemented')

    if args.config_file is None:
        parser.error('Need to supply config file.')

    if args.slurm:
        run_slurm(args.config_file, args.experiment)
    else:
        run_local(args.config_file, args.experiment, args.job_num)


def debug_runner(run, config):
    """Print the config, consume all of it, then sleep 10 seconds."""
    print(config)
    config.clear()
    import time
    time.sleep(10)


def sb3_runner(run, config):
    """Train a stable-baselines3 PPO agent as described by `config`, logging to wandb.

    Consumes the `video`, `test`, `env`, `algo` and `pca` sections; any other
    top-level key is an error.
    """
    videoC = consume(config, 'video', {})
    testC = consume(config, 'test', {})  # TODO: testing is not implemented yet
    envC = consume(config, 'env', {})
    algoC = consume(config, 'algo', {})
    pcaC = consume(config, 'pca', {})
    assert config == {}, ('Config was not completely consumed: ', config)

    env = DummyVecEnv([make_env_func(envC)])
    if consume(videoC, 'enable', False):
        # Read eagerly so a missing key fails at setup, not mid-training.
        frequency, length = videoC['frequency'], videoC['length']
        env = VecVideoRecorder(
            env,
            f"videos/{run.id}",
            record_video_trigger=lambda x: x % frequency == 0,
            video_length=length,
        )

    algo_name = consume(algoC, 'name')
    if algo_name != 'PPO':
        raise ValueError(f'Unsupported algorithm: {algo_name!r} (only PPO is supported).')
    policy_name = consume(algoC, 'policy_name')
    total_timesteps = consume(algoC, 'total_timesteps')
    # Remaining `algo` keys are forwarded to the PPO constructor.
    model = PPO(policy_name, env, **algoC)

    if consume(pcaC, 'enable', False):
        if PCA is None:
            raise NotImplementedError('PCA action distributions are not implemented yet.')
        model.policy.action_dist = PCA(model.policy.action_space.shape, **pcaC)

    try:
        model.learn(total_timesteps=total_timesteps, callback=WandbCallback())
    finally:
        # Closes the environments and, if enabled, finalizes the video recorder.
        env.close()


def make_env_func(env_conf):
    """Return a zero-argument factory building a Monitor-wrapped env from `env_conf`.

    `env_conf['name']` is the env id; `legacy_fancy` selects `fancy_gym.make`
    instead of `gym.make`; all remaining keys are passed as make kwargs.
    """
    conf = copy.deepcopy(env_conf)
    name = consume(conf, 'name')
    legacy_fancy = consume(conf, 'legacy_fancy', False)
    wrappers = consume(conf, 'wrappers', [])  # TODO: Implement wrappers

    def func():
        if legacy_fancy:  # TODO: Remove when no longer needed
            env = fancy_gym.make(name, **conf)
        else:
            env = gym.make(name, **conf)
        env = Monitor(env)
        return env

    return func


Runners = {
    'sb3': sb3_runner,
    'debug': debug_runner,
}

if __name__ == '__main__':
    main()