minor changes

This commit is contained in:
Dominik Moritz Roth 2023-07-05 21:18:57 +02:00
parent 4b395681d0
commit 03c79e378b
2 changed files with 59 additions and 34 deletions

View File

@ -18,6 +18,7 @@ slurm:
runner: debug runner: debug
scheduler: scheduler:
total_timesteps: 10000
repetitions: 3 repetitions: 3
agents_per_job: 3 agents_per_job: 3
reps_per_agent: 1 reps_per_agent: 1
@ -51,7 +52,6 @@ env:
algo: algo:
name: PPO name: PPO
total_timesteps: 10000
policy_name: MlpPolicy policy_name: MlpPolicy
n_steps: 4096 n_steps: 4096
vf_coef: 1.0e-5 vf_coef: 1.0e-5
@ -67,15 +67,45 @@ pca:
skip_conditioning: True skip_conditioning: True
Base_Noise: WHITE Base_Noise: WHITE
init_std: 1.0 init_std: 1.0
---
sweep: sweep:
enable: True enable: False
method: random, method: random
metric: #metric:
goal: minimize, # goal: minimize
name: score # name: score
parameters: parameters:
lel: lol algo.learning_rate:
min: 0.0001
max: 0.1
---
name: sweep
import: $
sweep.enable: True
scheduler.reps_per_agent: 3
---
name: McNamo
import: :DEFAULT
video:
length: 10
env:
env_args:
more_obs: False
algo.name: TRPL
leaf: False
---
name: Leaf
vars:
leaf: True
---
name: Weird
import: :McNamo,:Leaf
leaf: True
--- ---
ablative: ablative:
task: task:

47
main.py
View File

@ -1,10 +1,10 @@
import fancy_gym #import fancy_gym
from stable_baselines3 import PPO #from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor #from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder #from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
import wandb import wandb
from wandb.integration.sb3 import WandbCallback from wandb.integration.sb3 import WandbCallback
import gymnasium as gym #import gymnasium as gym
import yaml import yaml
import os import os
import random import random
@ -31,22 +31,26 @@ PCA = None
def load_config(filename, name): def load_config(filename, name):
config = _load_config(filename, name) config, stack = _load_config(filename, name)
print('[i] Merged Configs: ', stack)
deep_expand_vars(config, config=config) deep_expand_vars(config, config=config)
consume(config, 'vars', {}) consume(config, 'vars', {})
return config return config
def _load_config(filename, name): def _load_config(filename, name, stack=[]):
stack.append(f'{filename}:{name}')
with open(filename, 'r') as f: with open(filename, 'r') as f:
docs = yaml.safe_load_all(f) docs = yaml.safe_load_all(f)
for doc in docs: for doc in docs:
if 'name' in doc: if 'name' in doc:
if doc['name'] == name: if doc['name'] == name:
if 'import' in doc: if 'import' in doc:
imports = reversed(doc['import'].split(',')) imports = doc['import'].split(',')
del doc['import'] del doc['import']
for imp in imports: for imp in imports:
if imp[0] == ' ':
imp = imp[1:]
if imp == "$": if imp == "$":
imp = ':DEFAULT' imp = ':DEFAULT'
rel_path, *opt = imp.split(':') rel_path, *opt = imp.split(':')
@ -55,29 +59,22 @@ def _load_config(filename, name):
elif len(opt) == 1: elif len(opt) == 1:
nested_name = opt[0] nested_name = opt[0]
else: else:
raise Exception('Malformed import statement. Must be <import file:exp>, <import .:exp> or <import file> for file:DEFAULT.') raise Exception('Malformed import statement. Must be <import file:exp>, <import :exp>, <import file> for file:DEFAULT or <import $> for :DEFAULT.')
nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if len(rel_path) else filename nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if len(rel_path) else filename
child = _load_config(nested_path, nested_name) child, stack = _load_config(nested_path, nested_name, stack=stack)
doc = deep_update(child, doc) doc = deep_update(child, doc)
return doc return doc, stack
raise Exception(f'Unable to find experiment <{name}> in <{filename}>') raise Exception(f'Unable to find experiment <{name}> in <{filename}>')
def deep_update_old(d, u):
for k, v in u.items():
if isinstance(v, collections.abc.Mapping):
d[k] = deep_update_old(d.get(k, {}), v)
else:
d[k] = v
return d
def deep_update(d, u): def deep_update(d, u):
for kstr, v in u.items(): for kstr, v in u.items():
ks = kstr.split('.') ks = kstr.split('.')
head = d head = d
for k in ks: for k in ks:
last_head = head last_head = head
if k not in head:
head[k] = {}
head = head[k] head = head[k]
if isinstance(v, collections.abc.Mapping): if isinstance(v, collections.abc.Mapping):
last_head[ks[-1]] = deep_update(d.get(k, {}), v) last_head[ks[-1]] = deep_update(d.get(k, {}), v)
@ -107,8 +104,8 @@ def deep_expand_vars(dict, **kwargs):
apply_nested(dict, lambda x: expand_vars(x, **kwargs)) apply_nested(dict, lambda x: expand_vars(x, **kwargs))
def consume(conf, keys, default=None): def consume(conf, key, default=None):
keys_arr = keys.split('.') keys_arr = key.split('.')
if len(keys_arr) == 1: if len(keys_arr) == 1:
k = keys_arr[0] k = keys_arr[0]
if default != None: if default != None:
@ -133,7 +130,7 @@ def run_local(filename, name, job_num=None):
project=project project=project
) )
runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {}) runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {})
wandb.agent(sweep_id, function=partial(run_from_sweep, config, runnerName, project, wandbC), count=config['scheduler']['reps_per_agent']) wandb.agent(sweep_id, function=partial(run_from_sweep, config, runnerName, project, wandbC), count=config['run']['reps_per_agent'])
else: else:
consume(config, 'sweep', {}) consume(config, 'sweep', {})
run_single(config) run_single(config)
@ -232,7 +229,7 @@ def sb3_runner(run, config):
assert consume(algoC, 'name') == 'PPO' assert consume(algoC, 'name') == 'PPO'
policy_name = consume(algoC, 'policy_name') policy_name = consume(algoC, 'policy_name')
total_timesteps = consume(algoC, 'total_timesteps') total_timesteps = config.get('run', {}).get('total_timesteps', {})
model = PPO(policy_name, env, **algoC) model = PPO(policy_name, env, **algoC)
@ -267,8 +264,6 @@ def make_env_func(env_conf):
Runners = { Runners = {
'sb3': sb3_runner, 'sb3': sb3_runner,
'debug': debug_runner 'debug': debug_runner
} }
if __name__ == '__main__': if __name__ == '__main__':