minor changes
This commit is contained in:
parent
4b395681d0
commit
03c79e378b
46
config.yaml
46
config.yaml
@ -18,6 +18,7 @@ slurm:
|
||||
runner: debug
|
||||
|
||||
scheduler:
|
||||
total_timesteps: 10000
|
||||
repetitions: 3
|
||||
agents_per_job: 3
|
||||
reps_per_agent: 1
|
||||
@ -51,7 +52,6 @@ env:
|
||||
|
||||
algo:
|
||||
name: PPO
|
||||
total_timesteps: 10000
|
||||
policy_name: MlpPolicy
|
||||
n_steps: 4096
|
||||
vf_coef: 1.0e-5
|
||||
@ -67,15 +67,45 @@ pca:
|
||||
skip_conditioning: True
|
||||
Base_Noise: WHITE
|
||||
init_std: 1.0
|
||||
---
|
||||
|
||||
sweep:
|
||||
enable: True
|
||||
method: random,
|
||||
metric:
|
||||
goal: minimize,
|
||||
name: score
|
||||
enable: False
|
||||
method: random
|
||||
#metric:
|
||||
# goal: minimize
|
||||
# name: score
|
||||
parameters:
|
||||
lel: lol
|
||||
algo.learning_rate:
|
||||
min: 0.0001
|
||||
max: 0.1
|
||||
---
|
||||
name: sweep
|
||||
import: $
|
||||
sweep.enable: True
|
||||
scheduler.reps_per_agent: 3
|
||||
---
|
||||
name: McNamo
|
||||
import: :DEFAULT
|
||||
|
||||
video:
|
||||
length: 10
|
||||
|
||||
env:
|
||||
env_args:
|
||||
more_obs: False
|
||||
|
||||
algo.name: TRPL
|
||||
leaf: False
|
||||
---
|
||||
name: Leaf
|
||||
|
||||
vars:
|
||||
leaf: True
|
||||
---
|
||||
name: Weird
|
||||
import: :McNamo,:Leaf
|
||||
|
||||
leaf: True
|
||||
---
|
||||
ablative:
|
||||
task:
|
||||
|
47
main.py
47
main.py
@ -1,10 +1,10 @@
|
||||
import fancy_gym
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
|
||||
#import fancy_gym
|
||||
#from stable_baselines3 import PPO
|
||||
#from stable_baselines3.common.monitor import Monitor
|
||||
#from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
|
||||
import wandb
|
||||
from wandb.integration.sb3 import WandbCallback
|
||||
import gymnasium as gym
|
||||
#import gymnasium as gym
|
||||
import yaml
|
||||
import os
|
||||
import random
|
||||
@ -31,22 +31,26 @@ PCA = None
|
||||
|
||||
|
||||
def load_config(filename, name):
|
||||
config = _load_config(filename, name)
|
||||
config, stack = _load_config(filename, name)
|
||||
print('[i] Merged Configs: ', stack)
|
||||
deep_expand_vars(config, config=config)
|
||||
consume(config, 'vars', {})
|
||||
return config
|
||||
|
||||
|
||||
def _load_config(filename, name):
|
||||
def _load_config(filename, name, stack=[]):
|
||||
stack.append(f'{filename}:{name}')
|
||||
with open(filename, 'r') as f:
|
||||
docs = yaml.safe_load_all(f)
|
||||
for doc in docs:
|
||||
if 'name' in doc:
|
||||
if doc['name'] == name:
|
||||
if 'import' in doc:
|
||||
imports = reversed(doc['import'].split(','))
|
||||
imports = doc['import'].split(',')
|
||||
del doc['import']
|
||||
for imp in imports:
|
||||
if imp[0] == ' ':
|
||||
imp = imp[1:]
|
||||
if imp == "$":
|
||||
imp = ':DEFAULT'
|
||||
rel_path, *opt = imp.split(':')
|
||||
@ -55,29 +59,22 @@ def _load_config(filename, name):
|
||||
elif len(opt) == 1:
|
||||
nested_name = opt[0]
|
||||
else:
|
||||
raise Exception('Malformed import statement. Must be <import file:exp>, <import .:exp> or <import file> for file:DEFAULT.')
|
||||
raise Exception('Malformed import statement. Must be <import file:exp>, <import :exp>, <import file> for file:DEFAULT or <import $> for :DEFAULT.')
|
||||
nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if len(rel_path) else filename
|
||||
child = _load_config(nested_path, nested_name)
|
||||
child, stack = _load_config(nested_path, nested_name, stack=stack)
|
||||
doc = deep_update(child, doc)
|
||||
return doc
|
||||
return doc, stack
|
||||
raise Exception(f'Unable to find experiment <{name}> in <{filename}>')
|
||||
|
||||
|
||||
def deep_update_old(d, u):
|
||||
for k, v in u.items():
|
||||
if isinstance(v, collections.abc.Mapping):
|
||||
d[k] = deep_update_old(d.get(k, {}), v)
|
||||
else:
|
||||
d[k] = v
|
||||
return d
|
||||
|
||||
|
||||
def deep_update(d, u):
|
||||
for kstr, v in u.items():
|
||||
ks = kstr.split('.')
|
||||
head = d
|
||||
for k in ks:
|
||||
last_head = head
|
||||
if k not in head:
|
||||
head[k] = {}
|
||||
head = head[k]
|
||||
if isinstance(v, collections.abc.Mapping):
|
||||
last_head[ks[-1]] = deep_update(d.get(k, {}), v)
|
||||
@ -107,8 +104,8 @@ def deep_expand_vars(dict, **kwargs):
|
||||
apply_nested(dict, lambda x: expand_vars(x, **kwargs))
|
||||
|
||||
|
||||
def consume(conf, keys, default=None):
|
||||
keys_arr = keys.split('.')
|
||||
def consume(conf, key, default=None):
|
||||
keys_arr = key.split('.')
|
||||
if len(keys_arr) == 1:
|
||||
k = keys_arr[0]
|
||||
if default != None:
|
||||
@ -133,7 +130,7 @@ def run_local(filename, name, job_num=None):
|
||||
project=project
|
||||
)
|
||||
runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {})
|
||||
wandb.agent(sweep_id, function=partial(run_from_sweep, config, runnerName, project, wandbC), count=config['scheduler']['reps_per_agent'])
|
||||
wandb.agent(sweep_id, function=partial(run_from_sweep, config, runnerName, project, wandbC), count=config['run']['reps_per_agent'])
|
||||
else:
|
||||
consume(config, 'sweep', {})
|
||||
run_single(config)
|
||||
@ -232,7 +229,7 @@ def sb3_runner(run, config):
|
||||
assert consume(algoC, 'name') == 'PPO'
|
||||
policy_name = consume(algoC, 'policy_name')
|
||||
|
||||
total_timesteps = consume(algoC, 'total_timesteps')
|
||||
total_timesteps = config.get('run', {}).get('total_timesteps', {})
|
||||
|
||||
model = PPO(policy_name, env, **algoC)
|
||||
|
||||
@ -267,8 +264,6 @@ def make_env_func(env_conf):
|
||||
Runners = {
|
||||
'sb3': sb3_runner,
|
||||
'debug': debug_runner
|
||||
|
||||
|
||||
}
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user