Added lots of things
This commit is contained in:
parent
3aab15e00e
commit
320ec6dd03
198
main.py
198
main.py
@ -7,12 +7,38 @@ from wandb.integration.sb3 import WandbCallback
|
||||
import gymnasium as gym
|
||||
import yaml
|
||||
import os
|
||||
import random
|
||||
import copy
|
||||
import collections.abc
|
||||
|
||||
import pdb
|
||||
d = pdb.set_trace
|
||||
|
||||
|
||||
try:
|
||||
import pyslurm
|
||||
except ImportError:
|
||||
slurm_avaible = False
|
||||
else:
|
||||
slurm_avaible = True
|
||||
|
||||
|
||||
PCA = None
|
||||
|
||||
|
||||
# TODO: Implement Testing
|
||||
# TODO: Implement PCA
|
||||
# TODO: Implement Slurm
|
||||
# TODO: Implement Parallel
|
||||
|
||||
def load_config(filename, name):
|
||||
config = _load_config(filename, name)
|
||||
deep_expand_vars(config, config=config)
|
||||
consume(config, 'vars', {})
|
||||
return config
|
||||
|
||||
|
||||
def _load_config(filename, name):
|
||||
with open(filename, 'r') as f:
|
||||
docs = yaml.safe_load_all(f)
|
||||
for doc in docs:
|
||||
@ -28,11 +54,12 @@ def load_config(filename, name):
|
||||
elif len(opt) == 1:
|
||||
nested_name = opt[0]
|
||||
else:
|
||||
raise Exception()
|
||||
raise Exception('Malformed import statement. Must be <import file:exp>, <import .:exp> or <import file> for file:DEFAULT.')
|
||||
nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if len(rel_path) else filename
|
||||
child = load_config(nested_path, nested_name)
|
||||
child = _load_config(nested_path, nested_name)
|
||||
doc = deep_update(child, doc)
|
||||
return doc
|
||||
raise Exception(f'Unable to find experiment <{name}> in <{filename}>')
|
||||
|
||||
|
||||
def deep_update(d, u):
|
||||
@ -44,7 +71,44 @@ def deep_update(d, u):
|
||||
return d
|
||||
|
||||
|
||||
def run(filename, name):
|
||||
def expand_vars(string, **kwargs):
|
||||
if isinstance(string, str):
|
||||
return string.format(**kwargs)
|
||||
return string
|
||||
|
||||
|
||||
def apply_nested(d, f):
|
||||
for k, v in d.items():
|
||||
if isinstance(v, dict):
|
||||
apply_nested(v, f)
|
||||
elif isinstance(v, list):
|
||||
for i, e in enumerate(v):
|
||||
apply_nested({'PTR': d[k][i]}, f)
|
||||
else:
|
||||
d[k] = f(v)
|
||||
|
||||
|
||||
def deep_expand_vars(dict, **kwargs):
|
||||
apply_nested(dict, lambda x: expand_vars(x, **kwargs))
|
||||
|
||||
|
||||
def consume(conf, keys, default=None):
|
||||
keys_arr = keys.split('.')
|
||||
if len(keys_arr) == 1:
|
||||
k = keys_arr[0]
|
||||
if default != None:
|
||||
val = conf.get(k, default)
|
||||
else:
|
||||
val = conf[k]
|
||||
if k in conf:
|
||||
del conf[k]
|
||||
return val
|
||||
child = conf[keys_arr[0]]
|
||||
child_keys = '.'.join(keys_arr[1:])
|
||||
return consume(child, child_keys, default=default)
|
||||
|
||||
|
||||
def run_local(filename, name, job_num=None):
|
||||
config = load_config(filename, name)
|
||||
if 'sweep' in config and config['sweep']['enable']:
|
||||
sweepC = config['sweep']
|
||||
@ -58,54 +122,126 @@ def run(filename, name):
|
||||
run_single(config)
|
||||
|
||||
|
||||
def run_slurm(filename, name):
|
||||
assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
|
||||
config = load_config(filename, name)
|
||||
slurmC = consume(config, 'slurm')
|
||||
s_name = consume(slurmC, 'name')
|
||||
|
||||
python_script = 'main.py'
|
||||
sh_lines = consume(slurmC, 'sh_lines', [])
|
||||
if venv := consume(slurmC, 'venv', False):
|
||||
sh_lines += [f'source activate {venv}']
|
||||
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
|
||||
script = " && ".join(sh_lines)
|
||||
|
||||
num_jobs = 1
|
||||
|
||||
last_job_idx = num_jobs - 1
|
||||
num_parallel_jobs = min(consume(config, 'slurm.num_parallel_jobs', num_jobs), num_jobs)
|
||||
array = f'0-{last_job_idx}%{num_parallel_jobs}'
|
||||
job = pyslurm.JobSubmitDescription(s_name, script=script, array=array, **config['slurm'])
|
||||
job_id = job.submit()
|
||||
print(f'[i] Job submitted to slurm with id {job_id}')
|
||||
|
||||
|
||||
def run_single(config):
|
||||
videoC, testC, envC, algoC, pcaC = config.get('video', {}), config.get('test', {}), config.get('env', {}), config.get('algo', {}), config.get('pca', {})
|
||||
runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {})
|
||||
|
||||
try:
|
||||
runner = Runners[runnerName]
|
||||
except:
|
||||
d()
|
||||
|
||||
with wandb.init(
|
||||
project=config['project'],
|
||||
project=consume(wandbC, 'project'),
|
||||
config=config,
|
||||
sync_tensorboard=True,
|
||||
monitor_gym=True,
|
||||
save_code=True,
|
||||
**wandbC
|
||||
) as run:
|
||||
env = DummyVecEnv([make_env_func(envC)])
|
||||
if videoC.get('enable', False):
|
||||
env = VecVideoRecorder(env, f"videos/{run.id}", record_video_trigger=lambda x: x % videoC['frequency'] == 0, video_length=videoC['length'])
|
||||
runner(run, config)
|
||||
|
||||
assert algoC['name'] == 'PPO'
|
||||
del algoC['name']
|
||||
policy_name = algoC['policy_name']
|
||||
del algoC['policy_name']
|
||||
model = PPO(policy_name env, **algo)
|
||||
assert config == {}, ('Config was not completely consumed: ', config)
|
||||
|
||||
if pcaC.get('enable', False):
|
||||
del pcaC['enable']
|
||||
model.policy.action_dist = PCA(model.policy.action_space.shape, **pcaC)
|
||||
|
||||
model.learn(
|
||||
total_timesteps=config["total_timesteps"],
|
||||
callback=WandbCallback()
|
||||
)
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("config_file", nargs='?', default=None)
|
||||
parser.add_argument("experiment", nargs='?', default='DEFAULT')
|
||||
parser.add_argument("-s", "--slurm", action="store_true")
|
||||
parser.add_argument("-w", "--worker", action="store_true")
|
||||
parser.add_argument("-j", "--job_num", default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.worker:
|
||||
raise Exception('Not yet implemented')
|
||||
|
||||
assert args.config_file != None, 'Need to supply config file.'
|
||||
if args.slurm:
|
||||
run_slurm(args.config_file, args.experiment)
|
||||
else:
|
||||
run_local(args.config_file, args.experiment, args.job_num)
|
||||
|
||||
|
||||
def debug_runner(run, config):
|
||||
print(config)
|
||||
for k in list(config.keys()):
|
||||
del config[k]
|
||||
import time
|
||||
time.sleep(10)
|
||||
|
||||
|
||||
def sb3_runner(run, config):
|
||||
videoC, testC, envC, algoC, pcaC = consume(config, 'video', {}), consume(config, 'test', {}), consume(config, 'env', {}), consume(config, 'algo', {}), consume(config, 'pca', {})
|
||||
assert config == {}
|
||||
|
||||
env = DummyVecEnv([make_env_func(envC)])
|
||||
if consume(videoC, 'enable', False):
|
||||
env = VecVideoRecorder(env, f"videos/{run.id}", record_video_trigger=lambda x: x % videoC['frequency'] == 0, video_length=videoC['length'])
|
||||
|
||||
assert consume(algoC, 'name') == 'PPO'
|
||||
policy_name = consume(algoC, 'policy_name')
|
||||
|
||||
total_timesteps = consume(algoC, 'total_timesteps')
|
||||
|
||||
model = PPO(policy_name, env, **algoC)
|
||||
|
||||
if consume(pcaC, 'enable', False):
|
||||
model.policy.action_dist = PCA(model.policy.action_space.shape, **pcaC)
|
||||
|
||||
model.learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=WandbCallback()
|
||||
)
|
||||
|
||||
|
||||
def make_env_func(env_conf):
|
||||
conf = copy.deepcopy(env_conf)
|
||||
name = consume(conf, 'name')
|
||||
legacy_fancy = consume(conf, 'legacy_fancy', False)
|
||||
wrappers = consume(conf, 'wrappers', [])
|
||||
|
||||
def func():
|
||||
legacy_fancy = env_conf.get('legacy_fancy', False)
|
||||
del env_conf['name']
|
||||
if 'legacy_fancy' in env_conf:
|
||||
del env_conf['legacy_fancy']
|
||||
if legacy_fancy: # TODO: Remove when no longer needed
|
||||
fancy_gym.make(env_conf['name'], **env_conf)
|
||||
fancy_gym.make(name, **conf)
|
||||
else:
|
||||
env = gym.make(env_conf['name'], **env_conf)
|
||||
env = gym.make(name, **conf)
|
||||
|
||||
# TODO: Implement wrappers
|
||||
|
||||
env = Monitor(env)
|
||||
return env
|
||||
return func
|
||||
|
||||
|
||||
def main():
|
||||
run()
|
||||
Runners = {
|
||||
'sb3': sb3_runner,
|
||||
'debug': debug_runner
|
||||
|
||||
|
||||
}
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
Loading…
Reference in New Issue
Block a user