minor changes

This commit is contained in:
parent 4b395681d0
commit 03c79e378b

config.yaml (46 changed lines)
@@ -18,6 +18,7 @@ slurm:
 runner: debug

 scheduler:
+total_timesteps: 10000
 repetitions: 3
 agents_per_job: 3
 reps_per_agent: 1
@@ -51,7 +52,6 @@ env:

 algo:
 name: PPO
-total_timesteps: 10000
 policy_name: MlpPolicy
 n_steps: 4096
 vf_coef: 1.0e-5
@@ -67,15 +67,45 @@ pca:
 skip_conditioning: True
 Base_Noise: WHITE
 init_std: 1.0
----
+
 sweep:
-enable: True
-method: random,
-metric:
-goal: minimize,
-name: score
+enable: False
+method: random
+#metric:
+# goal: minimize
+# name: score
 parameters:
-lel: lol
+algo.learning_rate:
+min: 0.0001
+max: 0.1
+---
+name: sweep
+import: $
+sweep.enable: True
+scheduler.reps_per_agent: 3
+---
+name: McNamo
+import: :DEFAULT
+
+video:
+length: 10
+
+env:
+env_args:
+more_obs: False
+
+algo.name: TRPL
+leaf: False
+---
+name: Leaf
+
+vars:
+leaf: True
+---
+name: Weird
+import: :McNamo,:Leaf
+
+leaf: True
 ---
 ablative:
 task:
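The extra YAML documents above (sweep, McNamo, Leaf, Weird) are selected by name through load_config in main.py, which walks each document's import: chain and merges the imported documents underneath it. A minimal usage sketch, assuming the first document in config.yaml is named DEFAULT (not visible in these hunks) and that main.py is pointed at this file:

    # Hypothetical illustration only; load_config/_load_config are defined in main.py below.
    from main import load_config

    config = load_config('config.yaml', 'McNamo')
    # _load_config records every merged document, so the added print should show something like:
    # [i] Merged Configs:  ['config.yaml:McNamo', 'config.yaml:DEFAULT']
    print(config['algo']['name'])   # expected 'TRPL': the dotted key algo.name overrides the imported default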
main.py (47 changed lines)
@@ -1,10 +1,10 @@
-import fancy_gym
-from stable_baselines3 import PPO
-from stable_baselines3.common.monitor import Monitor
-from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
+#import fancy_gym
+#from stable_baselines3 import PPO
+#from stable_baselines3.common.monitor import Monitor
+#from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
 import wandb
 from wandb.integration.sb3 import WandbCallback
-import gymnasium as gym
+#import gymnasium as gym
 import yaml
 import os
 import random
@@ -31,22 +31,26 @@ PCA = None


 def load_config(filename, name):
-config = _load_config(filename, name)
+config, stack = _load_config(filename, name)
+print('[i] Merged Configs: ', stack)
 deep_expand_vars(config, config=config)
 consume(config, 'vars', {})
 return config


-def _load_config(filename, name):
+def _load_config(filename, name, stack=[]):
+stack.append(f'{filename}:{name}')
 with open(filename, 'r') as f:
 docs = yaml.safe_load_all(f)
 for doc in docs:
 if 'name' in doc:
 if doc['name'] == name:
 if 'import' in doc:
-imports = reversed(doc['import'].split(','))
+imports = doc['import'].split(',')
 del doc['import']
 for imp in imports:
+if imp[0] == ' ':
+imp = imp[1:]
 if imp == "$":
 imp = ':DEFAULT'
 rel_path, *opt = imp.split(':')
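For reference, a small sketch of how the reworked import handling treats a multi-import string such as the one used by the Weird experiment (import: :McNamo,:Leaf). This only mirrors the split/strip logic above and is not part of the commit:

    imp_str = ':McNamo,:Leaf'
    imports = imp_str.split(',')                                      # [':McNamo', ':Leaf'], order now preserved
    imports = [i[1:] if i.startswith(' ') else i for i in imports]    # drop a single leading space, as above
    for imp in imports:
        rel_path, *opt = imp.split(':')                               # empty rel_path means "same file"
        print((rel_path, opt))                                        # ('', ['McNamo']) then ('', ['Leaf'])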
@@ -55,29 +59,22 @@ def _load_config(filename, name):
 elif len(opt) == 1:
 nested_name = opt[0]
 else:
-raise Exception('Malformed import statement. Must be <import file:exp>, <import .:exp> or <import file> for file:DEFAULT.')
+raise Exception('Malformed import statement. Must be <import file:exp>, <import :exp>, <import file> for file:DEFAULT or <import $> for :DEFAULT.')
 nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if len(rel_path) else filename
-child = _load_config(nested_path, nested_name)
+child, stack = _load_config(nested_path, nested_name, stack=stack)
 doc = deep_update(child, doc)
-return doc
+return doc, stack
 raise Exception(f'Unable to find experiment <{name}> in <{filename}>')


-def deep_update_old(d, u):
-for k, v in u.items():
-if isinstance(v, collections.abc.Mapping):
-d[k] = deep_update_old(d.get(k, {}), v)
-else:
-d[k] = v
-return d
-
-
 def deep_update(d, u):
 for kstr, v in u.items():
 ks = kstr.split('.')
 head = d
 for k in ks:
 last_head = head
+if k not in head:
+head[k] = {}
 head = head[k]
 if isinstance(v, collections.abc.Mapping):
 last_head[ks[-1]] = deep_update(d.get(k, {}), v)
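The two added lines let deep_update create missing parent dicts when a dotted key addresses a nested entry that does not exist yet. A rough illustration of the intended behaviour (not the function itself, and not part of the commit):

    # Intended dotted-key merge, as used by config overrides like 'algo.name: TRPL'.
    base = {'algo': {'name': 'PPO', 'n_steps': 4096}}
    override = {'algo.name': 'TRPL', 'video.length': 10}
    # deep_update(base, override) is meant to yield:
    # {'algo': {'name': 'TRPL', 'n_steps': 4096}, 'video': {'length': 10}}
    # 'video.length' only works because missing parents are now created on the fly.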
@@ -107,8 +104,8 @@ def deep_expand_vars(dict, **kwargs):
 apply_nested(dict, lambda x: expand_vars(x, **kwargs))


-def consume(conf, keys, default=None):
-keys_arr = keys.split('.')
+def consume(conf, key, default=None):
+keys_arr = key.split('.')
 if len(keys_arr) == 1:
 k = keys_arr[0]
 if default != None:
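Elsewhere in main.py, consume appears to pop a (possibly dotted) key out of the config dict, optionally falling back to a default. Typical call sites, taken from the hunks in this commit:

    runnerName = consume(config, 'runner')        # required key, no default
    wandbC = consume(config, 'wandb', {})         # optional section, defaults to {}
    sweepC = consume(config, 'sweep', {})         # discarded in run_local when sweeps are disabled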
@@ -133,7 +130,7 @@ def run_local(filename, name, job_num=None):
 project=project
 )
 runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {})
-wandb.agent(sweep_id, function=partial(run_from_sweep, config, runnerName, project, wandbC), count=config['scheduler']['reps_per_agent'])
+wandb.agent(sweep_id, function=partial(run_from_sweep, config, runnerName, project, wandbC), count=config['run']['reps_per_agent'])
 else:
 consume(config, 'sweep', {})
 run_single(config)
@@ -232,7 +229,7 @@ def sb3_runner(run, config):
 assert consume(algoC, 'name') == 'PPO'
 policy_name = consume(algoC, 'policy_name')

-total_timesteps = consume(algoC, 'total_timesteps')
+total_timesteps = config.get('run', {}).get('total_timesteps', {})

 model = PPO(policy_name, env, **algoC)

@@ -267,8 +264,6 @@ def make_env_func(env_conf):
 Runners = {
 'sb3': sb3_runner,
 'debug': debug_runner
-
-
 }

 if __name__ == '__main__':