From 03c79e378b97411cb1b79b2e443c020da1f55f72 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 5 Jul 2023 21:18:57 +0200 Subject: [PATCH] minor changes --- config.yaml | 46 ++++++++++++++++++++++++++++++++++++++-------- main.py | 47 +++++++++++++++++++++-------------------------- 2 files changed, 59 insertions(+), 34 deletions(-) diff --git a/config.yaml b/config.yaml index 74c9fb9..ca5526f 100644 --- a/config.yaml +++ b/config.yaml @@ -18,6 +18,7 @@ slurm: runner: debug scheduler: + total_timesteps: 10000 repetitions: 3 agents_per_job: 3 reps_per_agent: 1 @@ -51,7 +52,6 @@ env: algo: name: PPO - total_timesteps: 10000 policy_name: MlpPolicy n_steps: 4096 vf_coef: 1.0e-5 @@ -67,15 +67,45 @@ pca: skip_conditioning: True Base_Noise: WHITE init_std: 1.0 ---- + sweep: - enable: True - method: random, - metric: - goal: minimize, - name: score + enable: False + method: random + #metric: + # goal: minimize + # name: score parameters: - lel: lol + algo.learning_rate: + min: 0.0001 + max: 0.1 +--- +name: sweep +import: $ +sweep.enable: True +scheduler.reps_per_agent: 3 +--- +name: McNamo +import: :DEFAULT + +video: + length: 10 + +env: + env_args: + more_obs: False + +algo.name: TRPL +leaf: False +--- +name: Leaf + +vars: + leaf: True +--- +name: Weird +import: :McNamo,:Leaf + +leaf: True --- ablative: task: diff --git a/main.py b/main.py index 896f45f..bc7b762 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,10 @@ -import fancy_gym -from stable_baselines3 import PPO -from stable_baselines3.common.monitor import Monitor -from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder +#import fancy_gym +#from stable_baselines3 import PPO +#from stable_baselines3.common.monitor import Monitor +#from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder import wandb from wandb.integration.sb3 import WandbCallback -import gymnasium as gym +#import gymnasium as gym import yaml import os import random @@ -31,22 +31,26 @@ PCA = None def load_config(filename, name): - config = _load_config(filename, name) + config, stack = _load_config(filename, name) + print('[i] Merged Configs: ', stack) deep_expand_vars(config, config=config) consume(config, 'vars', {}) return config -def _load_config(filename, name): +def _load_config(filename, name, stack=[]): + stack.append(f'{filename}:{name}') with open(filename, 'r') as f: docs = yaml.safe_load_all(f) for doc in docs: if 'name' in doc: if doc['name'] == name: if 'import' in doc: - imports = reversed(doc['import'].split(',')) + imports = doc['import'].split(',') del doc['import'] for imp in imports: + if imp[0] == ' ': + imp = imp[1:] if imp == "$": imp = ':DEFAULT' rel_path, *opt = imp.split(':') @@ -55,29 +59,22 @@ def _load_config(filename, name): elif len(opt) == 1: nested_name = opt[0] else: - raise Exception('Malformed import statement. Must be , or for file:DEFAULT.') + raise Exception('Malformed import statement. Must be , , for file:DEFAULT or for :DEFAULT.') nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if len(rel_path) else filename - child = _load_config(nested_path, nested_name) + child, stack = _load_config(nested_path, nested_name, stack=stack) doc = deep_update(child, doc) - return doc + return doc, stack raise Exception(f'Unable to find experiment <{name}> in <{filename}>') -def deep_update_old(d, u): - for k, v in u.items(): - if isinstance(v, collections.abc.Mapping): - d[k] = deep_update_old(d.get(k, {}), v) - else: - d[k] = v - return d - - def deep_update(d, u): for kstr, v in u.items(): ks = kstr.split('.') head = d for k in ks: last_head = head + if k not in head: + head[k] = {} head = head[k] if isinstance(v, collections.abc.Mapping): last_head[ks[-1]] = deep_update(d.get(k, {}), v) @@ -107,8 +104,8 @@ def deep_expand_vars(dict, **kwargs): apply_nested(dict, lambda x: expand_vars(x, **kwargs)) -def consume(conf, keys, default=None): - keys_arr = keys.split('.') +def consume(conf, key, default=None): + keys_arr = key.split('.') if len(keys_arr) == 1: k = keys_arr[0] if default != None: @@ -133,7 +130,7 @@ def run_local(filename, name, job_num=None): project=project ) runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {}) - wandb.agent(sweep_id, function=partial(run_from_sweep, config, runnerName, project, wandbC), count=config['scheduler']['reps_per_agent']) + wandb.agent(sweep_id, function=partial(run_from_sweep, config, runnerName, project, wandbC), count=config['run']['reps_per_agent']) else: consume(config, 'sweep', {}) run_single(config) @@ -232,7 +229,7 @@ def sb3_runner(run, config): assert consume(algoC, 'name') == 'PPO' policy_name = consume(algoC, 'policy_name') - total_timesteps = consume(algoC, 'total_timesteps') + total_timesteps = config.get('run', {}).get('total_timesteps', {}) model = PPO(policy_name, env, **algoC) @@ -267,8 +264,6 @@ def make_env_func(env_conf): Runners = { 'sb3': sb3_runner, 'debug': debug_runner - - } if __name__ == '__main__':