Sweeps should now be functional
This commit is contained in:
parent
b7ad796671
commit
4b395681d0
58
main.py
58
main.py
@ -10,11 +10,11 @@ import os
|
||||
import random
|
||||
import copy
|
||||
import collections.abc
|
||||
from functools import partial
|
||||
|
||||
import pdb
|
||||
d = pdb.set_trace
|
||||
|
||||
|
||||
try:
|
||||
import pyslurm
|
||||
except ImportError:
|
||||
@ -22,15 +22,14 @@ except ImportError:
|
||||
else:
|
||||
slurm_avaible = True
|
||||
|
||||
|
||||
PCA = None
|
||||
|
||||
|
||||
# TODO: Implement Testing
|
||||
# TODO: Implement PCA
|
||||
# TODO: Implement Slurm
|
||||
# TODO: Implement Parallel
|
||||
|
||||
|
||||
def load_config(filename, name):
|
||||
config = _load_config(filename, name)
|
||||
deep_expand_vars(config, config=config)
|
||||
@ -48,6 +47,8 @@ def _load_config(filename, name):
|
||||
imports = reversed(doc['import'].split(','))
|
||||
del doc['import']
|
||||
for imp in imports:
|
||||
if imp == "$":
|
||||
imp = ':DEFAULT'
|
||||
rel_path, *opt = imp.split(':')
|
||||
if len(opt) == 0:
|
||||
nested_name = 'DEFAULT'
|
||||
@ -62,15 +63,29 @@ def _load_config(filename, name):
|
||||
raise Exception(f'Unable to find experiment <{name}> in <{filename}>')
|
||||
|
||||
|
||||
def deep_update(d, u):
|
||||
def deep_update_old(d, u):
|
||||
for k, v in u.items():
|
||||
if isinstance(v, collections.abc.Mapping):
|
||||
d[k] = deep_update(d.get(k, {}), v)
|
||||
d[k] = deep_update_old(d.get(k, {}), v)
|
||||
else:
|
||||
d[k] = v
|
||||
return d
|
||||
|
||||
|
||||
def deep_update(d, u):
|
||||
for kstr, v in u.items():
|
||||
ks = kstr.split('.')
|
||||
head = d
|
||||
for k in ks:
|
||||
last_head = head
|
||||
head = head[k]
|
||||
if isinstance(v, collections.abc.Mapping):
|
||||
last_head[ks[-1]] = deep_update(d.get(k, {}), v)
|
||||
else:
|
||||
last_head[ks[-1]] = v
|
||||
return d
|
||||
|
||||
|
||||
def expand_vars(string, **kwargs):
|
||||
if isinstance(string, str):
|
||||
return string.format(**kwargs)
|
||||
@ -103,25 +118,41 @@ def consume(conf, keys, default=None):
|
||||
if k in conf:
|
||||
del conf[k]
|
||||
return val
|
||||
child = conf[keys_arr[0]]
|
||||
child = conf.get(keys_arr[0], {})
|
||||
child_keys = '.'.join(keys_arr[1:])
|
||||
return consume(child, child_keys, default=default)
|
||||
|
||||
|
||||
def run_local(filename, name, job_num=None):
|
||||
config = load_config(filename, name)
|
||||
if 'sweep' in config and config['sweep']['enable']:
|
||||
sweepC = config['sweep']
|
||||
del sweepC['enable']
|
||||
if consume(config, 'sweep.enable', False):
|
||||
sweepC = consume(config, 'sweep')
|
||||
project = consume(config, 'wandb.project')
|
||||
sweep_id = wandb.sweep(
|
||||
sweep=sweepC,
|
||||
project=config['project']
|
||||
project=project
|
||||
)
|
||||
wandb.agent(sweep_id, function=run_single, count=config['reps_per_agent'])
|
||||
runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {})
|
||||
wandb.agent(sweep_id, function=partial(run_from_sweep, config, runnerName, project, wandbC), count=config['scheduler']['reps_per_agent'])
|
||||
else:
|
||||
consume(config, 'sweep', {})
|
||||
run_single(config)
|
||||
|
||||
|
||||
def run_from_sweep(orig_config, runnerName, project, wandbC):
|
||||
runner = Runners[runnerName]
|
||||
|
||||
with wandb.init(
|
||||
project=project,
|
||||
**wandbC
|
||||
) as run:
|
||||
config = copy.deepcopy(orig_config)
|
||||
deep_update(config, wandb.config)
|
||||
runner(run, config)
|
||||
|
||||
assert config == {}, ('Config was not completely consumed: ', config)
|
||||
|
||||
|
||||
def run_slurm(filename, name):
|
||||
assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
|
||||
config = load_config(filename, name)
|
||||
@ -148,10 +179,7 @@ def run_slurm(filename, name):
|
||||
def run_single(config):
|
||||
runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {})
|
||||
|
||||
try:
|
||||
runner = Runners[runnerName]
|
||||
except:
|
||||
d()
|
||||
runner = Runners[runnerName]
|
||||
|
||||
with wandb.init(
|
||||
project=consume(wandbC, 'project'),
|
||||
|
Loading…
Reference in New Issue
Block a user