From 320ec6dd0352be0cc365d6a27a48a549485e728a Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Wed, 5 Jul 2023 19:29:21 +0200
Subject: [PATCH] Added lots of things

---
 main.py | 198 +++++++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 167 insertions(+), 31 deletions(-)

diff --git a/main.py b/main.py
index 7ac67cd..469b30a 100644
--- a/main.py
+++ b/main.py
@@ -7,12 +7,38 @@ from wandb.integration.sb3 import WandbCallback
 import gymnasium as gym
 import yaml
 import os
+import random
+import copy
 import collections.abc
+import pdb
+d = pdb.set_trace
+
+
+try:
+    import pyslurm
+except ImportError:
+    slurm_available = False
+else:
+    slurm_available = True
+
 
 PCA = None
+# TODO: Implement Testing
+# TODO: Implement PCA
+# TODO: Implement Slurm
+# TODO: Implement Parallel
+
 
 def load_config(filename, name):
+    config = _load_config(filename, name)
+    deep_expand_vars(config, config=config)
+    consume(config, 'vars', {})
+    return config
+
+
+def _load_config(filename, name):
     with open(filename, 'r') as f:
         docs = yaml.safe_load_all(f)
         for doc in docs:
@@ -28,11 +54,12 @@ def load_config(filename, name):
             elif len(opt) == 1:
                 nested_name = opt[0]
             else:
-                raise Exception()
+                raise Exception('Malformed import statement. Must be <file>:<experiment>, <experiment>, or <file> for file:DEFAULT.')
             nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if len(rel_path) else filename
-            child = load_config(nested_path, nested_name)
+            child = _load_config(nested_path, nested_name)
             doc = deep_update(child, doc)
             return doc
+    raise Exception(f'Unable to find experiment <{name}> in <{filename}>')
 
 
 def deep_update(d, u):
@@ -44,7 +71,44 @@ def deep_update(d, u):
     return d
 
 
-def run(filename, name):
+def expand_vars(string, **kwargs):
+    if isinstance(string, str):
+        return string.format(**kwargs)
+    return string
+
+
+def apply_nested(d, f):
+    for k, v in d.items():
+        if isinstance(v, dict):
+            apply_nested(v, f)
+        elif isinstance(v, list):
+            for i, e in enumerate(v):
+                apply_nested({'PTR': d[k][i]}, f)
+        else:
+            d[k] = f(v)
+
+
+def deep_expand_vars(dict, **kwargs):
+    apply_nested(dict, lambda x: expand_vars(x, **kwargs))
+
+
+def consume(conf, keys, default=None):
+    keys_arr = keys.split('.')
+    if len(keys_arr) == 1:
+        k = keys_arr[0]
+        if default is not None:
+            val = conf.get(k, default)
+        else:
+            val = conf[k]
+        if k in conf:
+            del conf[k]
+        return val
+    child = conf[keys_arr[0]]
+    child_keys = '.'.join(keys_arr[1:])
+    return consume(child, child_keys, default=default)
+
+
+def run_local(filename, name, job_num=None):
     config = load_config(filename, name)
     if 'sweep' in config and config['sweep']['enable']:
         sweepC = config['sweep']
@@ -58,54 +122,126 @@ def run(filename, name):
         run_single(config)
 
 
+def run_slurm(filename, name):
+    assert slurm_available, 'pyslurm does not seem to be installed on this system.'
+    config = load_config(filename, name)
+    slurmC = consume(config, 'slurm')
+    s_name = consume(slurmC, 'name')
+
+    python_script = 'main.py'
+    sh_lines = consume(slurmC, 'sh_lines', [])
+    if venv := consume(slurmC, 'venv', False):
+        sh_lines += [f'source activate {venv}']
+    sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
+    script = " && ".join(sh_lines)
+
+    num_jobs = 1
+
+    last_job_idx = num_jobs - 1
+    num_parallel_jobs = min(consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
+    array = f'0-{last_job_idx}%{num_parallel_jobs}'
+    job = pyslurm.JobSubmitDescription(s_name, script=script, array=array, **slurmC)
+    job_id = job.submit()
+    print(f'[i] Job submitted to slurm with id {job_id}')
+
+
 def run_single(config):
-    videoC, testC, envC, algoC, pcaC = config.get('video', {}), config.get('test', {}), config.get('env', {}), config.get('algo', {}), config.get('pca', {})
+    runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {})
+
+    try:
+        runner = Runners[runnerName]
+    except:
+        d()
 
     with wandb.init(
-        project=config['project'],
+        project=consume(wandbC, 'project'),
         config=config,
-        sync_tensorboard=True,
-        monitor_gym=True,
-        save_code=True,
+        **wandbC
     ) as run:
-        env = DummyVecEnv([make_env_func(envC)])
-        if videoC.get('enable', False):
-            env = VecVideoRecorder(env, f"videos/{run.id}", record_video_trigger=lambda x: x % videoC['frequency'] == 0, video_length=videoC['length'])
+        runner(run, config)
 
-        assert algoC['name'] == 'PPO'
-        del algoC['name']
-        policy_name = algoC['policy_name']
-        del algoC['policy_name']
-        model = PPO(policy_name env, **algo)
+    assert config == {}, ('Config was not completely consumed: ', config)
 
-        if pcaC.get('enable', False):
-            del pcaC['enable']
-            model.policy.action_dist = PCA(model.policy.action_space.shape, **pcaC)
 
-        model.learn(
-            total_timesteps=config["total_timesteps"],
-            callback=WandbCallback()
-        )
+def main():
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("config_file", nargs='?', default=None)
+    parser.add_argument("experiment", nargs='?', default='DEFAULT')
+    parser.add_argument("-s", "--slurm", action="store_true")
+    parser.add_argument("-w", "--worker", action="store_true")
+    parser.add_argument("-j", "--job_num", default=None)
+
+    args = parser.parse_args()
+
+    if args.worker:
+        raise Exception('Not yet implemented')
+
+    assert args.config_file is not None, 'Need to supply config file.'
+    if args.slurm:
+        run_slurm(args.config_file, args.experiment)
+    else:
+        run_local(args.config_file, args.experiment, args.job_num)
+
+
+def debug_runner(run, config):
+    print(config)
+    for k in list(config.keys()):
+        del config[k]
+    import time
+    time.sleep(10)
+
+
+def sb3_runner(run, config):
+    videoC, testC, envC, algoC, pcaC = consume(config, 'video', {}), consume(config, 'test', {}), consume(config, 'env', {}), consume(config, 'algo', {}), consume(config, 'pca', {})
+    assert config == {}
+
+    env = DummyVecEnv([make_env_func(envC)])
+    if consume(videoC, 'enable', False):
+        env = VecVideoRecorder(env, f"videos/{run.id}", record_video_trigger=lambda x: x % videoC['frequency'] == 0, video_length=videoC['length'])
+
+    assert consume(algoC, 'name') == 'PPO'
+    policy_name = consume(algoC, 'policy_name')
+
+    total_timesteps = consume(algoC, 'total_timesteps')
+
+    model = PPO(policy_name, env, **algoC)
+
+    if consume(pcaC, 'enable', False):
+        model.policy.action_dist = PCA(model.policy.action_space.shape, **pcaC)
+
+    model.learn(
+        total_timesteps=total_timesteps,
+        callback=WandbCallback()
+    )
 
 
 def make_env_func(env_conf):
+    conf = copy.deepcopy(env_conf)
+    name = consume(conf, 'name')
+    legacy_fancy = consume(conf, 'legacy_fancy', False)
+    wrappers = consume(conf, 'wrappers', [])
+
     def func():
-        legacy_fancy = env_conf.get('legacy_fancy', False)
-        del env_conf['name']
-        if 'legacy_fancy' in env_conf:
-            del env_conf['legacy_fancy']
         if legacy_fancy:  # TODO: Remove when no longer needed
-            fancy_gym.make(env_conf['name'], **env_conf)
+            env = fancy_gym.make(name, **conf)
         else:
-            env = gym.make(env_conf['name'], **env_conf)
+            env = gym.make(name, **conf)
+
+        # TODO: Implement wrappers
+
         env = Monitor(env)
         return env
     return func
 
 
-def main():
-    run()
+Runners = {
+    'sb3': sb3_runner,
+    'debug': debug_runner
+}
+
 
 if __name__ == '__main__':
     main()
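
Note (not part of the patch): the new config handling revolves around consume(), which pops a key, or a dotted path of keys, out of the config dict; run_single() then asserts that nothing was left unconsumed, so unknown or misspelled config keys fail loudly instead of being silently ignored. A minimal sketch of that behaviour follows, with the consume() helper copied from the hunk above and a made-up config dict purely for illustration:

# Illustration only, not part of the commit; the config values below are invented.
def consume(conf, keys, default=None):
    # Copied from the patch: pop a (possibly dotted) key out of a nested dict.
    keys_arr = keys.split('.')
    if len(keys_arr) == 1:
        k = keys_arr[0]
        if default is not None:
            val = conf.get(k, default)
        else:
            val = conf[k]
        if k in conf:
            del conf[k]
        return val
    child = conf[keys_arr[0]]
    child_keys = '.'.join(keys_arr[1:])
    return consume(child, child_keys, default=default)


config = {
    'runner': 'sb3',
    'wandb': {'project': 'demo-project'},
    'algo': {'name': 'PPO', 'policy_name': 'MlpPolicy', 'total_timesteps': 1000},
}

runner_name = consume(config, 'runner')     # -> 'sb3', and 'runner' is removed
project = consume(config, 'wandb.project')  # dotted path walks into the nested dict
algoC = consume(config, 'algo')             # whole sub-dict is popped for the runner

assert runner_name == 'sb3' and project == 'demo-project'
assert config == {'wandb': {}}              # only the drained parent dict remains

The same pattern is what makes the assert config == {} checks in run_single() and sb3_runner() effective: any key a runner does not recognise stays behind and trips the assertion.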