Sweeps should now be functional
This commit is contained in:
parent
b7ad796671
commit
4b395681d0
58
main.py
58
main.py
@ -10,11 +10,11 @@ import os
|
|||||||
import random
|
import random
|
||||||
import copy
|
import copy
|
||||||
import collections.abc
|
import collections.abc
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import pdb
|
import pdb
|
||||||
d = pdb.set_trace
|
d = pdb.set_trace
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pyslurm
|
import pyslurm
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -22,15 +22,14 @@ except ImportError:
|
|||||||
else:
|
else:
|
||||||
slurm_avaible = True
|
slurm_avaible = True
|
||||||
|
|
||||||
|
|
||||||
PCA = None
|
PCA = None
|
||||||
|
|
||||||
|
|
||||||
# TODO: Implement Testing
|
# TODO: Implement Testing
|
||||||
# TODO: Implement PCA
|
# TODO: Implement PCA
|
||||||
# TODO: Implement Slurm
|
# TODO: Implement Slurm
|
||||||
# TODO: Implement Parallel
|
# TODO: Implement Parallel
|
||||||
|
|
||||||
|
|
||||||
def load_config(filename, name):
|
def load_config(filename, name):
|
||||||
config = _load_config(filename, name)
|
config = _load_config(filename, name)
|
||||||
deep_expand_vars(config, config=config)
|
deep_expand_vars(config, config=config)
|
||||||
@ -48,6 +47,8 @@ def _load_config(filename, name):
|
|||||||
imports = reversed(doc['import'].split(','))
|
imports = reversed(doc['import'].split(','))
|
||||||
del doc['import']
|
del doc['import']
|
||||||
for imp in imports:
|
for imp in imports:
|
||||||
|
if imp == "$":
|
||||||
|
imp = ':DEFAULT'
|
||||||
rel_path, *opt = imp.split(':')
|
rel_path, *opt = imp.split(':')
|
||||||
if len(opt) == 0:
|
if len(opt) == 0:
|
||||||
nested_name = 'DEFAULT'
|
nested_name = 'DEFAULT'
|
||||||
@ -62,15 +63,29 @@ def _load_config(filename, name):
|
|||||||
raise Exception(f'Unable to find experiment <{name}> in <{filename}>')
|
raise Exception(f'Unable to find experiment <{name}> in <{filename}>')
|
||||||
|
|
||||||
|
|
||||||
def deep_update(d, u):
|
def deep_update_old(d, u):
|
||||||
for k, v in u.items():
|
for k, v in u.items():
|
||||||
if isinstance(v, collections.abc.Mapping):
|
if isinstance(v, collections.abc.Mapping):
|
||||||
d[k] = deep_update(d.get(k, {}), v)
|
d[k] = deep_update_old(d.get(k, {}), v)
|
||||||
else:
|
else:
|
||||||
d[k] = v
|
d[k] = v
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def deep_update(d, u):
|
||||||
|
for kstr, v in u.items():
|
||||||
|
ks = kstr.split('.')
|
||||||
|
head = d
|
||||||
|
for k in ks:
|
||||||
|
last_head = head
|
||||||
|
head = head[k]
|
||||||
|
if isinstance(v, collections.abc.Mapping):
|
||||||
|
last_head[ks[-1]] = deep_update(d.get(k, {}), v)
|
||||||
|
else:
|
||||||
|
last_head[ks[-1]] = v
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
def expand_vars(string, **kwargs):
|
def expand_vars(string, **kwargs):
|
||||||
if isinstance(string, str):
|
if isinstance(string, str):
|
||||||
return string.format(**kwargs)
|
return string.format(**kwargs)
|
||||||
@ -103,25 +118,41 @@ def consume(conf, keys, default=None):
|
|||||||
if k in conf:
|
if k in conf:
|
||||||
del conf[k]
|
del conf[k]
|
||||||
return val
|
return val
|
||||||
child = conf[keys_arr[0]]
|
child = conf.get(keys_arr[0], {})
|
||||||
child_keys = '.'.join(keys_arr[1:])
|
child_keys = '.'.join(keys_arr[1:])
|
||||||
return consume(child, child_keys, default=default)
|
return consume(child, child_keys, default=default)
|
||||||
|
|
||||||
|
|
||||||
def run_local(filename, name, job_num=None):
|
def run_local(filename, name, job_num=None):
|
||||||
config = load_config(filename, name)
|
config = load_config(filename, name)
|
||||||
if 'sweep' in config and config['sweep']['enable']:
|
if consume(config, 'sweep.enable', False):
|
||||||
sweepC = config['sweep']
|
sweepC = consume(config, 'sweep')
|
||||||
del sweepC['enable']
|
project = consume(config, 'wandb.project')
|
||||||
sweep_id = wandb.sweep(
|
sweep_id = wandb.sweep(
|
||||||
sweep=sweepC,
|
sweep=sweepC,
|
||||||
project=config['project']
|
project=project
|
||||||
)
|
)
|
||||||
wandb.agent(sweep_id, function=run_single, count=config['reps_per_agent'])
|
runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {})
|
||||||
|
wandb.agent(sweep_id, function=partial(run_from_sweep, config, runnerName, project, wandbC), count=config['scheduler']['reps_per_agent'])
|
||||||
else:
|
else:
|
||||||
|
consume(config, 'sweep', {})
|
||||||
run_single(config)
|
run_single(config)
|
||||||
|
|
||||||
|
|
||||||
|
def run_from_sweep(orig_config, runnerName, project, wandbC):
|
||||||
|
runner = Runners[runnerName]
|
||||||
|
|
||||||
|
with wandb.init(
|
||||||
|
project=project,
|
||||||
|
**wandbC
|
||||||
|
) as run:
|
||||||
|
config = copy.deepcopy(orig_config)
|
||||||
|
deep_update(config, wandb.config)
|
||||||
|
runner(run, config)
|
||||||
|
|
||||||
|
assert config == {}, ('Config was not completely consumed: ', config)
|
||||||
|
|
||||||
|
|
||||||
def run_slurm(filename, name):
|
def run_slurm(filename, name):
|
||||||
assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
|
assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
|
||||||
config = load_config(filename, name)
|
config = load_config(filename, name)
|
||||||
@ -148,10 +179,7 @@ def run_slurm(filename, name):
|
|||||||
def run_single(config):
|
def run_single(config):
|
||||||
runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {})
|
runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {})
|
||||||
|
|
||||||
try:
|
runner = Runners[runnerName]
|
||||||
runner = Runners[runnerName]
|
|
||||||
except:
|
|
||||||
d()
|
|
||||||
|
|
||||||
with wandb.init(
|
with wandb.init(
|
||||||
project=consume(wandbC, 'project'),
|
project=consume(wandbC, 'project'),
|
||||||
|
Loading…
Reference in New Issue
Block a user