Slate/main.py

271 lines
8.1 KiB
Python
Raw Normal View History

2023-07-05 21:18:57 +02:00
#import fancy_gym
#from stable_baselines3 import PPO
#from stable_baselines3.common.monitor import Monitor
#from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
2023-07-05 15:02:53 +02:00
import wandb
from wandb.integration.sb3 import WandbCallback
2023-07-05 21:18:57 +02:00
#import gymnasium as gym
2023-07-05 15:02:53 +02:00
import yaml
import os
2023-07-05 19:29:21 +02:00
import random
import copy
2023-07-05 15:02:53 +02:00
import collections.abc
2023-07-05 20:30:57 +02:00
from functools import partial
2023-07-05 15:02:53 +02:00
2023-07-05 19:29:21 +02:00
import pdb

# Short alias so a breakpoint can be dropped anywhere with `d()`.
d = pdb.set_trace

# pyslurm is optional; run_slurm() refuses to run when it is missing.
try:
    import pyslurm
    slurm_avaible = True
except ImportError:
    slurm_avaible = False

# Placeholder for a PCA action distribution; not implemented yet.
PCA = None

# TODO: Implement Testing
# TODO: Implement PCA
# TODO: Implement Slurm
# TODO: Implement Parallel
2023-07-05 20:30:57 +02:00
2023-07-05 15:02:53 +02:00
def load_config(filename, name):
    """Load experiment ``name`` from the YAML file ``filename``.

    Imports are resolved recursively by ``_load_config``. ``{...}``
    placeholders in string values are then expanded with ``str.format``
    (the merged config is available as ``config``), and the helper ``vars``
    section is removed.

    Args:
        filename: Path to the YAML config file.
        name: Value of the ``name`` field identifying the experiment.

    Returns:
        The merged and expanded config dict.
    """
    # Pass a fresh list so repeated calls never report entries from
    # earlier loads in the "Merged Configs" line.
    config, stack = _load_config(filename, name, stack=[])
    print('[i] Merged Configs: ', stack)
    deep_expand_vars(config, config=config)
    consume(config, 'vars', {})
    return config
2023-07-05 21:18:57 +02:00
def _load_config(filename, name, stack=None):
    """Recursively load experiment ``name`` from ``filename``, resolving imports.

    ``filename`` may contain several YAML documents. The first mapping
    document whose ``name`` field equals ``name`` is used. Its optional
    ``import`` entry is a comma-separated string of references; a YAML list
    of strings is also accepted. The accepted reference forms are
    ``file:exp``, ``:exp`` (same file), ``file`` (``file:DEFAULT``) and
    ``$`` (``:DEFAULT``). Relative paths are resolved against the directory
    of ``filename``.

    Each referenced experiment is loaded recursively, and this document is
    deep-merged on top of it. Values from this document win, and earlier
    imports win over later ones.

    Args:
        filename: Path to the YAML file to search.
        name: Experiment name to look for.
        stack: List recording every ``file:name`` visited. A new list is
            created when omitted.

    Returns:
        ``(config, stack)``.

    Raises:
        Exception: if an import reference is malformed or the experiment
            cannot be found in ``filename``.
    """
    if stack is None:
        stack = []
    stack.append(f'{filename}:{name}')

    # Read every document up front so the file is closed before recursing.
    with open(filename, 'r') as f:
        docs = list(yaml.safe_load_all(f))

    for doc in docs:
        # Empty documents load as None; only mappings can carry a name.
        if not isinstance(doc, dict) or 'name' not in doc or doc['name'] != name:
            continue

        imports = doc.pop('import', None)
        if imports is not None:
            if isinstance(imports, str):
                imports = imports.split(',')
            for imp in imports:
                imp = imp.strip()
                if not imp:
                    # Tolerate trailing or doubled commas.
                    continue
                if imp == "$":
                    imp = ':DEFAULT'
                rel_path, *opt = imp.split(':')
                if len(opt) == 0:
                    nested_name = 'DEFAULT'
                elif len(opt) == 1:
                    nested_name = opt[0]
                else:
                    raise Exception('Malformed import statement. Must be <import file:exp>, <import :exp>, <import file> for file:DEFAULT or <import $> for :DEFAULT.')
                if rel_path:
                    nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path))
                else:
                    nested_path = filename
                child, stack = _load_config(nested_path, nested_name, stack=stack)
                # This document (plus anything merged so far) overrides the import.
                doc = deep_update(child, doc)
        return doc, stack
    raise Exception(f'Unable to find experiment <{name}> in <{filename}>')
2023-07-05 15:02:53 +02:00
2023-07-05 20:30:57 +02:00
def deep_update(d, u):
    """Recursively merge mapping ``u`` into dict ``d`` in place and return ``d``.

    Keys of ``u`` may be dotted paths (``'a.b.c'``) that address nested dicts
    inside ``d``. Missing intermediate levels are created, and intermediate
    values that are not mappings are replaced by dicts. A mapping value is
    merged recursively into the sub-dict at its key; any other value
    replaces the existing entry.
    """
    for kstr, v in u.items():
        *parents, leaf = kstr.split('.')
        head = d
        for k in parents:
            if not isinstance(head.get(k), collections.abc.Mapping):
                head[k] = {}
            head = head[k]
        if isinstance(v, collections.abc.Mapping):
            # Merge into the entry addressed by the full path, not into a
            # top-level key that happens to share the last segment's name.
            existing = head.get(leaf)
            if not isinstance(existing, collections.abc.Mapping):
                existing = {}
            head[leaf] = deep_update(existing, v)
        else:
            head[leaf] = v
    return d
2023-07-05 19:29:21 +02:00
def expand_vars(string, **kwargs):
    """Return ``string.format(**kwargs)`` for strings; any other value unchanged."""
    if not isinstance(string, str):
        return string
    return string.format(**kwargs)
def apply_nested(d, f):
    """Replace every leaf value nested inside dict ``d`` with ``f(leaf)``, in place.

    Dicts and lists are traversed recursively, including lists nested in
    lists. Every other value is a leaf. Results are written back into the
    containing dict or list, so leaves stored in lists are updated too.
    Returns None.
    """
    def _transform(value):
        if isinstance(value, dict):
            apply_nested(value, f)
            return value
        if isinstance(value, list):
            for i, element in enumerate(value):
                value[i] = _transform(element)
            return value
        return f(value)

    # Only values are reassigned (no keys added or removed), so mutating
    # while iterating over items() is safe.
    for k, v in d.items():
        d[k] = _transform(v)
def deep_expand_vars(dict, **kwargs):
    """Run ``expand_vars`` with ``kwargs`` over every nested value of ``dict``, in place.

    The parameter name shadows the builtin ``dict``; it is kept unchanged
    for interface compatibility.
    """
    def _expand(value):
        return expand_vars(value, **kwargs)

    apply_nested(dict, _expand)
2023-07-05 21:18:57 +02:00
def consume(conf, key, default=None):
    """Remove and return the value at dotted path ``key`` from ``conf``.

    ``key`` may address nested dicts (``'a.b.c'``). Only the leaf entry is
    removed; intermediate dicts stay in place even if they become empty.
    Missing intermediate levels are treated as empty dicts.

    Args:
        conf: Config dict to consume from (mutated).
        key: Dotted key path.
        default: Value returned when the leaf key is missing. ``None`` means
            "no default", so a missing key raises ``KeyError``.

    Returns:
        The removed value, or ``default``.

    Raises:
        KeyError: if the key is missing and ``default`` is None.
    """
    *parents, leaf = key.split('.')
    node = conf
    for part in parents:
        node = node.get(part, {})
    if default is not None:
        return node.pop(leaf, default)
    return node.pop(leaf)
def run_local(filename, name, job_num=None):
    """Run experiment ``name`` from ``filename`` on this machine.

    When ``sweep.enable`` is set, a W&B sweep is created from the ``sweep``
    section. A local agent then runs ``run.reps_per_agent`` trials through
    ``run_from_sweep``. Otherwise a single run is started with
    ``run_single``. ``job_num`` is accepted but currently unused.
    """
    config = load_config(filename, name)

    if not consume(config, 'sweep.enable', False):
        # Drop the disabled sweep section so it does not count as unconsumed.
        consume(config, 'sweep', {})
        run_single(config)
        return

    sweepC = consume(config, 'sweep')
    project = consume(config, 'wandb.project')
    sweep_id = wandb.sweep(sweep=sweepC, project=project)

    runnerName = consume(config, 'runner')
    wandbC = consume(config, 'wandb', {})
    agent_fn = partial(run_from_sweep, config, runnerName, project, wandbC)
    wandb.agent(sweep_id, function=agent_fn, count=config['run']['reps_per_agent'])
2023-07-05 20:30:57 +02:00
def run_from_sweep(orig_config, runnerName, project, wandbC):
    """W&B agent callback that executes a single sweep trial.

    Starts a run and deep-copies ``orig_config``. It then overlays the
    parameters chosen by the sweep (``wandb.config``) and passes the result
    to the runner registered under ``runnerName``. The runner must consume
    every key.
    """
    selected_runner = Runners[runnerName]
    with wandb.init(project=project, **wandbC) as run:
        # Work on a copy so every trial starts from the same base config.
        trial_config = copy.deepcopy(orig_config)
        deep_update(trial_config, wandb.config)
        selected_runner(run, trial_config)
        assert trial_config == {}, ('Config was not completely consumed: ', trial_config)
2023-07-05 19:29:21 +02:00
def run_slurm(filename, name):
    """Submit experiment ``name`` from ``filename`` as a Slurm array job.

    The config's ``slurm`` section controls the submission:

    - ``name`` (required): job name.
    - ``sh_lines``: shell commands run before the experiment.
    - ``venv``: environment activated with ``source activate``.
    - ``num_parallel_jobs``: capped at the number of array tasks.

    All remaining ``slurm`` keys are forwarded to
    ``pyslurm.JobSubmitDescription``. Each array task re-invokes the script
    with ``-j $SLURM_ARRAY_TASK_ID``.
    """
    assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
    config = load_config(filename, name)
    slurmC = consume(config, 'slurm')
    s_name = consume(slurmC, 'name')
    # NOTE(review): relative path; assumes the job starts in the directory
    # containing main.py.
    python_script = 'main.py'
    sh_lines = consume(slurmC, 'sh_lines', [])
    if venv := consume(slurmC, 'venv', False):
        sh_lines += [f'source activate {venv}']
    sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
    # sbatch rejects batch scripts that do not start with an interpreter line.
    script = '#!/bin/bash\n' + " && ".join(sh_lines)
    num_jobs = 1  # TODO: only a single array task is submitted for now
    last_job_idx = num_jobs - 1
    # The slurm section was already moved into slurmC, so read the option
    # from there (config['slurm'] no longer exists at this point).
    num_parallel_jobs = min(consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
    array = f'0-{last_job_idx}%{num_parallel_jobs}'
    # JobSubmitDescription is configured through keyword arguments; the
    # remaining slurm options are forwarded as-is.
    job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
    job_id = job.submit()
    print(f'[i] Job submitted to slurm with id {job_id}')
2023-07-05 15:02:53 +02:00
def run_single(config):
    """Start one W&B run and execute the configured runner on ``config``.

    Consumes ``runner`` (a key of ``Runners``) and the optional ``wandb``
    section. Its ``project`` entry is required, and the remaining entries
    are passed to ``wandb.init``. The runner must consume every key left in
    ``config``.
    """
    runnerName = consume(config, 'runner')
    wandbC = consume(config, 'wandb', {})
    runner = Runners[runnerName]

    # Pull project out first so it is not passed twice via **wandbC.
    project = consume(wandbC, 'project')
    run_ctx = wandb.init(project=project, config=config, **wandbC)
    with run_ctx as run:
        runner(run, config)
        assert config == {}, ('Config was not completely consumed: ', config)
def main():
    """Command-line entry point.

    Usage: ``main.py CONFIG_FILE [EXPERIMENT] [-s] [-w] [-j JOB_NUM]``.
    ``-s`` submits the experiment to Slurm; otherwise it runs locally.
    ``-w`` (worker mode) is not implemented yet.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("config_file", nargs='?', default=None)
    parser.add_argument("experiment", nargs='?', default='DEFAULT')
    parser.add_argument("-s", "--slurm", action="store_true")
    parser.add_argument("-w", "--worker", action="store_true")
    parser.add_argument("-j", "--job_num", type=int, default=None)
    args = parser.parse_args()
    if args.worker:
        raise NotImplementedError('Not yet implemented')
    if args.config_file is None:
        # A usage error instead of an assert: asserts vanish under `python -O`.
        parser.error('Need to supply config file.')
    if args.slurm:
        run_slurm(args.config_file, args.experiment)
    else:
        run_local(args.config_file, args.experiment, args.job_num)
def debug_runner(run, config, sleep_seconds=10):
    """Runner for exercising the pipeline without training anything.

    Prints ``config``, then empties it so every key counts as consumed, then
    sleeps for ``sleep_seconds`` seconds (10 by default). ``run`` is unused.
    """
    import time
    print(config)
    config.clear()
    time.sleep(sleep_seconds)
def sb3_runner(run, config):
    """Train a Stable-Baselines3 PPO agent as described by ``config``.

    Consumed sections:

    - ``video``: optional recording; ``enable``, ``frequency``, ``length``.
    - ``test``: not used yet.
    - ``env``: see ``make_env_func``.
    - ``algo``: ``name`` must be ``'PPO'``; ``policy_name``; remaining keys
      go to ``PPO``.
    - ``pca``: optional; not implemented.
    - ``run``: ``total_timesteps`` is required; other keys are ignored here.

    NOTE(review): ``PPO``, ``DummyVecEnv`` and ``VecVideoRecorder`` must be
    importable at module level; their imports at the top of the file are
    currently commented out.
    """
    videoC = consume(config, 'video', {})
    consume(config, 'test', {})  # TODO: testing is not implemented yet
    envC = consume(config, 'env', {})
    algoC = consume(config, 'algo', {})
    pcaC = consume(config, 'pca', {})
    # 'run' has to be consumed before the emptiness check: total_timesteps
    # lives there, and reading it afterwards could never succeed.
    runC = consume(config, 'run', {})
    assert config == {}, ('Config was not completely consumed: ', config)

    env = DummyVecEnv([make_env_func(envC)])
    if consume(videoC, 'enable', False):
        frequency = videoC['frequency']
        env = VecVideoRecorder(
            env,
            f"videos/{run.id}",
            record_video_trigger=lambda step: step % frequency == 0,
            video_length=videoC['length'],
        )

    # Explicit check, not an assert: under `python -O` the consume() inside an
    # assert would be skipped and 'name' would leak into the PPO kwargs.
    algo_name = consume(algoC, 'name')
    if algo_name != 'PPO':
        raise ValueError(f'Unsupported algorithm <{algo_name}>; only PPO is implemented.')
    policy_name = consume(algoC, 'policy_name')
    total_timesteps = consume(runC, 'total_timesteps')

    model = PPO(policy_name, env, **algoC)
    if consume(pcaC, 'enable', False):
        if PCA is None:
            raise NotImplementedError('PCA is not implemented yet.')
        model.policy.action_dist = PCA(model.policy.action_space.shape, **pcaC)
    model.learn(
        total_timesteps=total_timesteps,
        callback=WandbCallback()
    )
2023-07-05 15:02:53 +02:00
def make_env_func(env_conf):
    """Return a zero-argument factory that builds the environment in ``env_conf``.

    ``env_conf`` is deep-copied, so the caller's dict is not modified.
    Recognised keys:

    - ``name`` (required): environment id.
    - ``legacy_fancy``: build with ``fancy_gym.make`` instead of ``gym.make``.
    - ``wrappers``: consumed but not applied yet.

    Every other key is passed as a keyword argument to the make call. The
    created environment is wrapped in ``Monitor``.

    NOTE(review): ``gym``, ``fancy_gym`` and ``Monitor`` are referenced here,
    but their imports at the top of the file are commented out.
    """
    conf = copy.deepcopy(env_conf)
    name = consume(conf, 'name')
    legacy_fancy = consume(conf, 'legacy_fancy', False)
    consume(conf, 'wrappers', [])  # TODO: Implement wrappers

    def func():
        if legacy_fancy:  # TODO: Remove when no longer needed
            # Bind the result: Monitor(env) below needs it in both branches.
            env = fancy_gym.make(name, **conf)
        else:
            env = gym.make(name, **conf)
        return Monitor(env)

    return func
2023-07-05 19:29:21 +02:00
# Maps the config's ``runner`` value to the function that executes the run.
Runners = dict(
    sb3=sb3_runner,
    debug=debug_runner,
)

if __name__ == '__main__':
    main()