minor changes

This commit is contained in:
Dominik Moritz Roth 2023-07-05 21:18:57 +02:00
parent 4b395681d0
commit 03c79e378b
2 changed files with 59 additions and 34 deletions

View File

@ -18,6 +18,7 @@ slurm:
runner: debug
scheduler:
total_timesteps: 10000
repetitions: 3
agents_per_job: 3
reps_per_agent: 1
@ -51,7 +52,6 @@ env:
algo:
name: PPO
total_timesteps: 10000
policy_name: MlpPolicy
n_steps: 4096
vf_coef: 1.0e-5
@ -67,15 +67,45 @@ pca:
skip_conditioning: True
Base_Noise: WHITE
init_std: 1.0
---
sweep:
enable: True
method: random,
metric:
goal: minimize,
name: score
enable: False
method: random
#metric:
# goal: minimize
# name: score
parameters:
lel: lol
algo.learning_rate:
min: 0.0001
max: 0.1
---
name: sweep
import: $
sweep.enable: True
scheduler.reps_per_agent: 3
---
name: McNamo
import: :DEFAULT
video:
length: 10
env:
env_args:
more_obs: False
algo.name: TRPL
leaf: False
---
name: Leaf
vars:
leaf: True
---
name: Weird
import: :McNamo,:Leaf
leaf: True
---
ablative:
task:

47
main.py
View File

@ -1,10 +1,10 @@
import fancy_gym
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
#import fancy_gym
#from stable_baselines3 import PPO
#from stable_baselines3.common.monitor import Monitor
#from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
import wandb
from wandb.integration.sb3 import WandbCallback
import gymnasium as gym
#import gymnasium as gym
import yaml
import os
import random
@ -31,22 +31,26 @@ PCA = None
def load_config(filename, name):
config = _load_config(filename, name)
config, stack = _load_config(filename, name)
print('[i] Merged Configs: ', stack)
deep_expand_vars(config, config=config)
consume(config, 'vars', {})
return config
def _load_config(filename, name):
def _load_config(filename, name, stack=[]):
stack.append(f'{filename}:{name}')
with open(filename, 'r') as f:
docs = yaml.safe_load_all(f)
for doc in docs:
if 'name' in doc:
if doc['name'] == name:
if 'import' in doc:
imports = reversed(doc['import'].split(','))
imports = doc['import'].split(',')
del doc['import']
for imp in imports:
if imp[0] == ' ':
imp = imp[1:]
if imp == "$":
imp = ':DEFAULT'
rel_path, *opt = imp.split(':')
@ -55,29 +59,22 @@ def _load_config(filename, name):
elif len(opt) == 1:
nested_name = opt[0]
else:
raise Exception('Malformed import statement. Must be <import file:exp>, <import .:exp> or <import file> for file:DEFAULT.')
raise Exception('Malformed import statement. Must be <import file:exp>, <import :exp>, <import file> for file:DEFAULT or <import $> for :DEFAULT.')
nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if len(rel_path) else filename
child = _load_config(nested_path, nested_name)
child, stack = _load_config(nested_path, nested_name, stack=stack)
doc = deep_update(child, doc)
return doc
return doc, stack
raise Exception(f'Unable to find experiment <{name}> in <{filename}>')
def deep_update_old(d, u):
    """Recursively merge mapping ``u`` into dict ``d`` in place.

    For every key in ``u``:
      * if the value is a mapping, it is merged into the dict already stored
        under that key in ``d`` (a new dict is created when the key is missing);
      * otherwise the value simply overwrites ``d[key]``.

    Args:
        d: The dict that is updated in place (the "base").
        u: The mapping whose entries take precedence (the "override").

    Returns:
        ``d`` itself, after the merge.
    """
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            target = d.get(k)
            # A nested override on top of a non-dict base value (scalar, None,
            # list, ...) used to recurse into that value and crash with a
            # TypeError on item assignment. The override wins, so start from
            # a fresh dict instead.
            if not isinstance(target, collections.abc.MutableMapping):
                target = {}
            d[k] = deep_update_old(target, v)
        else:
            d[k] = v
    return d
def deep_update(d, u):
for kstr, v in u.items():
ks = kstr.split('.')
head = d
for k in ks:
last_head = head
if k not in head:
head[k] = {}
head = head[k]
if isinstance(v, collections.abc.Mapping):
last_head[ks[-1]] = deep_update(d.get(k, {}), v)
@ -107,8 +104,8 @@ def deep_expand_vars(dict, **kwargs):
apply_nested(dict, lambda x: expand_vars(x, **kwargs))
def consume(conf, keys, default=None):
keys_arr = keys.split('.')
def consume(conf, key, default=None):
keys_arr = key.split('.')
if len(keys_arr) == 1:
k = keys_arr[0]
if default != None:
@ -133,7 +130,7 @@ def run_local(filename, name, job_num=None):
project=project
)
runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {})
wandb.agent(sweep_id, function=partial(run_from_sweep, config, runnerName, project, wandbC), count=config['scheduler']['reps_per_agent'])
wandb.agent(sweep_id, function=partial(run_from_sweep, config, runnerName, project, wandbC), count=config['run']['reps_per_agent'])
else:
consume(config, 'sweep', {})
run_single(config)
@ -232,7 +229,7 @@ def sb3_runner(run, config):
assert consume(algoC, 'name') == 'PPO'
policy_name = consume(algoC, 'policy_name')
total_timesteps = consume(algoC, 'total_timesteps')
total_timesteps = config.get('run', {}).get('total_timesteps', {})
model = PPO(policy_name, env, **algoC)
@ -267,8 +264,6 @@ def make_env_func(env_conf):
Runners = {
'sb3': sb3_runner,
'debug': debug_runner
}
if __name__ == '__main__':