From ea74d6a7123378204ff45a39aac6196a52e8e9ba Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Thu, 6 Jul 2023 18:06:20 +0200 Subject: [PATCH] Maybe it's a package now --- example.py | 73 ++++++++++ config.yaml => example.yaml | 0 main.py | 271 ------------------------------------ setup.py | 12 ++ slate/__init__.py | 1 + slate/slate.py | 206 +++++++++++++++++++++++++++ 6 files changed, 292 insertions(+), 271 deletions(-) create mode 100644 example.py rename config.yaml => example.yaml (100%) delete mode 100644 main.py create mode 100644 setup.py create mode 100644 slate/__init__.py create mode 100644 slate/slate.py diff --git a/example.py b/example.py new file mode 100644 index 0000000..808863e --- /dev/null +++ b/example.py @@ -0,0 +1,73 @@ +from slate import Slate + +import fancy_gym +from stable_baselines3 import PPO +from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder +from wandb.integration.sb3 import WandbCallback +import gymnasium as gym +import copy + +PCA = None + + +def debug_runner(slate, run, config): + print(config) + for k in list(config.keys()): + del config[k] + import time + time.sleep(10) + + +def sb3_runner(slate, run, config): + videoC, testC, envC, algoC, pcaC = slate.consume(config, 'video', {}), slate.consume(config, 'test', {}), slate.consume(config, + 'env', {}), slate.consume(config, 'algo', {}), slate.consume(config, 'pca', {}) + assert config == {} + + env = DummyVecEnv([make_env_func(slate, envC)]) + if slate.consume(videoC, 'enable', False): + env = VecVideoRecorder(env, f"videos/{run.id}", record_video_trigger=lambda x: x % videoC['frequency'] == 0, video_length=videoC['length']) + + assert slate.consume(algoC, 'name') == 'PPO' + policy_name = slate.consume(algoC, 'policy_name') + + total_timesteps = config.get('run', {}).get('total_timesteps', {}) + + model = PPO(policy_name, env, **algoC) + + if slate.consume(pcaC, 'enable', False): + model.policy.action_dist = PCA(model.policy.action_space.shape, **pcaC) + + model.learn( + total_timesteps=total_timesteps, + callback=WandbCallback() + ) + + +def make_env_func(slate, env_conf): + conf = copy.deepcopy(env_conf) + name = slate.consume(conf, 'name') + legacy_fancy = slate.consume(conf, 'legacy_fancy', False) + wrappers = slate.consume(conf, 'wrappers', []) + + def func(): + if legacy_fancy: # TODO: Remove when no longer needed + fancy_gym.make(name, **conf) + else: + env = gym.make(name, **conf) + + # TODO: Implement wrappers + + env = Monitor(env) + return env + return func + + +runners = { + 'sb3': sb3_runner, + 'debug': debug_runner +} + +if __name__ == '__main__': + slate = Slate(runners) + slate.from_args() diff --git a/config.yaml b/example.yaml similarity index 100% rename from config.yaml rename to example.yaml diff --git a/main.py b/main.py deleted file mode 100644 index 66aea16..0000000 --- a/main.py +++ /dev/null @@ -1,271 +0,0 @@ -#import fancy_gym -#from stable_baselines3 import PPO -#from stable_baselines3.common.monitor import Monitor -#from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder -import wandb -from wandb.integration.sb3 import WandbCallback -#import gymnasium as gym -import yaml -import os -import random -import copy -import collections.abc -from functools import partial - -import pdb -d = pdb.set_trace - -try: - import pyslurm -except ImportError: - slurm_avaible = False -else: - slurm_avaible = True - -PCA = None - -# TODO: Implement Slurm -# TODO: Implement Parallel -# TODO: Implement Testing -# TODO: Implement Ablative -# TODO: Implement PCA - - -def load_config(filename, name): - config, stack = _load_config(filename, name) - print('[i] Merged Configs: ', stack) - deep_expand_vars(config, config=config) - consume(config, 'vars', {}) - return config - - -def _load_config(filename, name, stack=[]): - stack.append(f'{filename}:{name}') - with open(filename, 'r') as f: - docs = yaml.safe_load_all(f) - for doc in docs: - if 'name' in doc: - if doc['name'] == name: - if 'import' in doc: - imports = doc['import'].split(',') - del doc['import'] - for imp in imports: - if imp[0] == ' ': - imp = imp[1:] - if imp == "$": - imp = ':DEFAULT' - rel_path, *opt = imp.split(':') - if len(opt) == 0: - nested_name = 'DEFAULT' - elif len(opt) == 1: - nested_name = opt[0] - else: - raise Exception('Malformed import statement. Must be , , for file:DEFAULT or for :DEFAULT.') - nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if len(rel_path) else filename - child, stack = _load_config(nested_path, nested_name, stack=stack) - doc = deep_update(child, doc) - return doc, stack - raise Exception(f'Unable to find experiment <{name}> in <{filename}>') - - -def deep_update(d, u): - for kstr, v in u.items(): - ks = kstr.split('.') - head = d - for k in ks: - last_head = head - if k not in head: - head[k] = {} - head = head[k] - if isinstance(v, collections.abc.Mapping): - last_head[ks[-1]] = deep_update(d.get(k, {}), v) - else: - last_head[ks[-1]] = v - return d - - -def expand_vars(string, **kwargs): - if isinstance(string, str): - return string.format(**kwargs) - return string - - -def apply_nested(d, f): - for k, v in d.items(): - if isinstance(v, dict): - apply_nested(v, f) - elif isinstance(v, list): - for i, e in enumerate(v): - apply_nested({'PTR': d[k][i]}, f) - else: - d[k] = f(v) - - -def deep_expand_vars(dict, **kwargs): - apply_nested(dict, lambda x: expand_vars(x, **kwargs)) - - -def consume(conf, key, default=None): - keys_arr = key.split('.') - if len(keys_arr) == 1: - k = keys_arr[0] - if default != None: - val = conf.get(k, default) - else: - val = conf[k] - if k in conf: - del conf[k] - return val - child = conf.get(keys_arr[0], {}) - child_keys = '.'.join(keys_arr[1:]) - return consume(child, child_keys, default=default) - - -def run_local(filename, name, job_num=None): - config = load_config(filename, name) - if consume(config, 'sweep.enable', False): - sweepC = consume(config, 'sweep') - project = consume(config, 'wandb.project') - sweep_id = wandb.sweep( - sweep=sweepC, - project=project - ) - runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {}) - wandb.agent(sweep_id, function=partial(run_from_sweep, config, runnerName, project, wandbC), count=config['run']['reps_per_agent']) - else: - consume(config, 'sweep', {}) - run_single(config) - - -def run_from_sweep(orig_config, runnerName, project, wandbC): - runner = Runners[runnerName] - - with wandb.init( - project=project, - **wandbC - ) as run: - config = copy.deepcopy(orig_config) - deep_update(config, wandb.config) - runner(run, config) - - assert config == {}, ('Config was not completely consumed: ', config) - - -def run_slurm(filename, name): - assert slurm_avaible, 'pyslurm does not seem to be installed on this system.' - config = load_config(filename, name) - slurmC = consume(config, 'slurm') - s_name = consume(slurmC, 'name') - - python_script = 'main.py' - sh_lines = consume(slurmC, 'sh_lines', []) - if venv := consume(slurmC, 'venv', False): - sh_lines += [f'source activate {venv}'] - sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID'] - script = " && ".join(sh_lines) - - num_jobs = 1 - - last_job_idx = num_jobs - 1 - num_parallel_jobs = min(consume(config, 'slurm.num_parallel_jobs', num_jobs), num_jobs) - array = f'0-{last_job_idx}%{num_parallel_jobs}' - job = pyslurm.JobSubmitDescription(s_name, script=script, array=array, **config['slurm']) - job_id = job.submit() - print(f'[i] Job submitted to slurm with id {job_id}') - - -def run_single(config): - runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {}) - - runner = Runners[runnerName] - - with wandb.init( - project=consume(wandbC, 'project'), - config=config, - **wandbC - ) as run: - runner(run, config) - - assert config == {}, ('Config was not completely consumed: ', config) - - -def main(): - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("config_file", nargs='?', default=None) - parser.add_argument("experiment", nargs='?', default='DEFAULT') - parser.add_argument("-s", "--slurm", action="store_true") - parser.add_argument("-w", "--worker", action="store_true") - parser.add_argument("-j", "--job_num", default=None) - - args = parser.parse_args() - - if args.worker: - raise Exception('Not yet implemented') - - assert args.config_file != None, 'Need to supply config file.' - if args.slurm: - run_slurm(args.config_file, args.experiment) - else: - run_local(args.config_file, args.experiment, args.job_num) - - -def debug_runner(run, config): - print(config) - for k in list(config.keys()): - del config[k] - import time - time.sleep(10) - - -def sb3_runner(run, config): - videoC, testC, envC, algoC, pcaC = consume(config, 'video', {}), consume(config, 'test', {}), consume(config, 'env', {}), consume(config, 'algo', {}), consume(config, 'pca', {}) - assert config == {} - - env = DummyVecEnv([make_env_func(envC)]) - if consume(videoC, 'enable', False): - env = VecVideoRecorder(env, f"videos/{run.id}", record_video_trigger=lambda x: x % videoC['frequency'] == 0, video_length=videoC['length']) - - assert consume(algoC, 'name') == 'PPO' - policy_name = consume(algoC, 'policy_name') - - total_timesteps = config.get('run', {}).get('total_timesteps', {}) - - model = PPO(policy_name, env, **algoC) - - if consume(pcaC, 'enable', False): - model.policy.action_dist = PCA(model.policy.action_space.shape, **pcaC) - - model.learn( - total_timesteps=total_timesteps, - callback=WandbCallback() - ) - - -def make_env_func(env_conf): - conf = copy.deepcopy(env_conf) - name = consume(conf, 'name') - legacy_fancy = consume(conf, 'legacy_fancy', False) - wrappers = consume(conf, 'wrappers', []) - - def func(): - if legacy_fancy: # TODO: Remove when no longer needed - fancy_gym.make(name, **conf) - else: - env = gym.make(name, **conf) - - # TODO: Implement wrappers - - env = Monitor(env) - return env - return func - - -Runners = { - 'sb3': sb3_runner, - 'debug': debug_runner -} - -if __name__ == '__main__': - main() diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..3db1e16 --- /dev/null +++ b/setup.py @@ -0,0 +1,12 @@ +from setuptools import setup, find_packages + +setup( + name='slate', + version='1.0.0', + # url='https://github.com/mypackage.git', + # author='Author Name', + # author_email='author@gmail.com', + # description='Description of my package', + packages=['.'], + install_requires=[], +) diff --git a/slate/__init__.py b/slate/__init__.py new file mode 100644 index 0000000..a149628 --- /dev/null +++ b/slate/__init__.py @@ -0,0 +1 @@ +from slate import Slate diff --git a/slate/slate.py b/slate/slate.py new file mode 100644 index 0000000..2e9b5e3 --- /dev/null +++ b/slate/slate.py @@ -0,0 +1,206 @@ +#import fancy_gym +#from stable_baselines3 import PPO +#from stable_baselines3.common.monitor import Monitor +#from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder +import wandb +from wandb.integration.sb3 import WandbCallback +#import gymnasium as gym +import yaml +import os +import random +import copy +import collections.abc +from functools import partial + +import pdb +d = pdb.set_trace + +try: + import pyslurm +except ImportError: + slurm_avaible = False +else: + slurm_avaible = True + +# TODO: Implement Slurm +# TODO: Implement Parallel +# TODO: Implement Testing +# TODO: Implement Ablative +# TODO: Implement PCA + + +class Slate(): + def __init__(self, runners): + self.runners = runners + + def load_config(self, filename, name): + config, stack = self._load_config(filename, name) + print('[i] Merged Configs: ', stack) + self.deep_expand_vars(config, config=config) + self.consume(config, 'vars', {}) + return config + + def _load_config(self, filename, name, stack=[]): + stack.append(f'{filename}:{name}') + with open(filename, 'r') as f: + docs = yaml.safe_load_all(f) + for doc in docs: + if 'name' in doc: + if doc['name'] == name: + if 'import' in doc: + imports = doc['import'].split(',') + del doc['import'] + for imp in imports: + if imp[0] == ' ': + imp = imp[1:] + if imp == "$": + imp = ':DEFAULT' + rel_path, *opt = imp.split(':') + if len(opt) == 0: + nested_name = 'DEFAULT' + elif len(opt) == 1: + nested_name = opt[0] + else: + raise Exception('Malformed import statement. Must be , , for file:DEFAULT or for :DEFAULT.') + nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if len(rel_path) else filename + child, stack = self._load_config(nested_path, nested_name, stack=stack) + doc = self.deep_update(child, doc) + return doc, stack + raise Exception(f'Unable to find experiment <{name}> in <{filename}>') + + def deep_update(self, d, u): + for kstr, v in u.items(): + ks = kstr.split('.') + head = d + for k in ks: + last_head = head + if k not in head: + head[k] = {} + head = head[k] + if isinstance(v, collections.abc.Mapping): + last_head[ks[-1]] = self.deep_update(d.get(k, {}), v) + else: + last_head[ks[-1]] = v + return d + + def expand_vars(self, string, **kwargs): + if isinstance(string, str): + return string.format(**kwargs) + return string + + def apply_nested(self, d, f): + for k, v in d.items(): + if isinstance(v, dict): + self.apply_nested(v, f) + elif isinstance(v, list): + for i, e in enumerate(v): + self.apply_nested({'PTR': d[k][i]}, f) + else: + d[k] = f(v) + + def deep_expand_vars(self, dict, **kwargs): + self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs)) + + def consume(self, conf, key, default=None): + keys_arr = key.split('.') + if len(keys_arr) == 1: + k = keys_arr[0] + if default != None: + val = conf.get(k, default) + else: + val = conf[k] + if k in conf: + del conf[k] + return val + child = conf.get(keys_arr[0], {}) + child_keys = '.'.join(keys_arr[1:]) + return self.consume(child, child_keys, default=default) + + def run_local(self, filename, name, job_num=None): + config = self.load_config(filename, name) + if self.consume(config, 'sweep.enable', False): + sweepC = self.consume(config, 'sweep') + project = self.consume(config, 'wandb.project') + sweep_id = wandb.sweep( + sweep=sweepC, + project=project + ) + runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {}) + wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, runnerName, project, wandbC), count=config['run']['reps_per_agent']) + else: + self.consume(config, 'sweep', {}) + self.run_single(config) + + def _run_from_sweep(self, orig_config, runnerName, project, wandbC): + runner = self.runners[runnerName] + + with wandb.init( + project=project, + **wandbC + ) as run: + config = copy.deepcopy(orig_config) + self.deep_update(config, wandb.config) + runner(run, config) + + assert config == {}, ('Config was not completely consumed: ', config) + + def run_slurm(self, filename, name): + assert slurm_avaible, 'pyslurm does not seem to be installed on this system.' + config = self.load_config(filename, name) + slurmC = self.consume(config, 'slurm') + s_name = self.consume(slurmC, 'name') + + python_script = 'main.py' + sh_lines = self.consume(slurmC, 'sh_lines', []) + if venv := self.consume(slurmC, 'venv', False): + sh_lines += [f'source activate {venv}'] + sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID'] + script = " && ".join(sh_lines) + + num_jobs = 1 + + last_job_idx = num_jobs - 1 + num_parallel_jobs = min(self.consume(config, 'slurm.num_parallel_jobs', num_jobs), num_jobs) + array = f'0-{last_job_idx}%{num_parallel_jobs}' + job = pyslurm.JobSubmitDescription(s_name, script=script, array=array, **config['slurm']) + job_id = job.submit() + print(f'[i] Job submitted to slurm with id {job_id}') + + def run_single(self, config): + runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {}) + + runner = Runners[runnerName] + + with wandb.init( + project=self.consume(wandbC, 'project'), + config=config, + **wandbC + ) as run: + runner(run, config) + + assert config == {}, ('Config was not completely consumed: ', config) + + def from_args(self): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("config_file", nargs='?', default=None) + parser.add_argument("experiment", nargs='?', default='DEFAULT') + parser.add_argument("-s", "--slurm", action="store_true") + parser.add_argument("-w", "--worker", action="store_true") + parser.add_argument("-j", "--job_num", default=None) + + args = parser.parse_args() + + if args.worker: + raise Exception('Not yet implemented') + + assert args.config_file != None, 'Need to supply config file.' + if args.slurm: + self.run_slurm(args.config_file, args.experiment) + else: + self.run_local(args.config_file, args.experiment, args.job_num) + + +if __name__ == '__main__': + raise Exception('You are using it wrong...')