diff --git a/main.py b/main.py index 469b30a..896f45f 100644 --- a/main.py +++ b/main.py @@ -10,11 +10,11 @@ import os import random import copy import collections.abc +from functools import partial import pdb d = pdb.set_trace - try: import pyslurm except ImportError: @@ -22,15 +22,14 @@ except ImportError: else: slurm_avaible = True - PCA = None - # TODO: Implement Testing # TODO: Implement PCA # TODO: Implement Slurm # TODO: Implement Parallel + def load_config(filename, name): config = _load_config(filename, name) deep_expand_vars(config, config=config) @@ -48,6 +47,8 @@ def _load_config(filename, name): imports = reversed(doc['import'].split(',')) del doc['import'] for imp in imports: + if imp == "$": + imp = ':DEFAULT' rel_path, *opt = imp.split(':') if len(opt) == 0: nested_name = 'DEFAULT' @@ -62,15 +63,29 @@ def _load_config(filename, name): raise Exception(f'Unable to find experiment <{name}> in <{filename}>') -def deep_update(d, u): +def deep_update_old(d, u): for k, v in u.items(): if isinstance(v, collections.abc.Mapping): - d[k] = deep_update(d.get(k, {}), v) + d[k] = deep_update_old(d.get(k, {}), v) else: d[k] = v return d +def deep_update(d, u): + for kstr, v in u.items(): + ks = kstr.split('.') + head = d + for k in ks: + last_head = head + head = head[k] + if isinstance(v, collections.abc.Mapping): + last_head[ks[-1]] = deep_update(d.get(k, {}), v) + else: + last_head[ks[-1]] = v + return d + + def expand_vars(string, **kwargs): if isinstance(string, str): return string.format(**kwargs) @@ -103,25 +118,41 @@ def consume(conf, keys, default=None): if k in conf: del conf[k] return val - child = conf[keys_arr[0]] + child = conf.get(keys_arr[0], {}) child_keys = '.'.join(keys_arr[1:]) return consume(child, child_keys, default=default) def run_local(filename, name, job_num=None): config = load_config(filename, name) - if 'sweep' in config and config['sweep']['enable']: - sweepC = config['sweep'] - del sweepC['enable'] + if consume(config, 'sweep.enable', False): + sweepC = consume(config, 'sweep') + project = consume(config, 'wandb.project') sweep_id = wandb.sweep( sweep=sweepC, - project=config['project'] + project=project ) - wandb.agent(sweep_id, function=run_single, count=config['reps_per_agent']) + runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {}) + wandb.agent(sweep_id, function=partial(run_from_sweep, config, runnerName, project, wandbC), count=config['scheduler']['reps_per_agent']) else: + consume(config, 'sweep', {}) run_single(config) +def run_from_sweep(orig_config, runnerName, project, wandbC): + runner = Runners[runnerName] + + with wandb.init( + project=project, + **wandbC + ) as run: + config = copy.deepcopy(orig_config) + deep_update(config, wandb.config) + runner(run, config) + + assert config == {}, ('Config was not completely consumed: ', config) + + def run_slurm(filename, name): assert slurm_avaible, 'pyslurm does not seem to be installed on this system.' config = load_config(filename, name) @@ -148,10 +179,7 @@ def run_slurm(filename, name): def run_single(config): runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {}) - try: - runner = Runners[runnerName] - except: - d() + runner = Runners[runnerName] with wandb.init( project=consume(wandbC, 'project'),