Sweeps should now be functional

This commit is contained in:
Dominik Moritz Roth 2023-07-05 20:30:57 +02:00
parent b7ad796671
commit 4b395681d0

58
main.py
View File

@ -10,11 +10,11 @@ import os
import random
import copy
import collections.abc
from functools import partial
import pdb
d = pdb.set_trace
try:
import pyslurm
except ImportError:
@ -22,15 +22,14 @@ except ImportError:
else:
slurm_avaible = True
PCA = None
# TODO: Implement Testing
# TODO: Implement PCA
# TODO: Implement Slurm
# TODO: Implement Parallel
def load_config(filename, name):
config = _load_config(filename, name)
deep_expand_vars(config, config=config)
@ -48,6 +47,8 @@ def _load_config(filename, name):
imports = reversed(doc['import'].split(','))
del doc['import']
for imp in imports:
if imp == "$":
imp = ':DEFAULT'
rel_path, *opt = imp.split(':')
if len(opt) == 0:
nested_name = 'DEFAULT'
@ -62,15 +63,29 @@ def _load_config(filename, name):
raise Exception(f'Unable to find experiment <{name}> in <{filename}>')
def deep_update(d, u):
def deep_update_old(d, u):
for k, v in u.items():
if isinstance(v, collections.abc.Mapping):
d[k] = deep_update(d.get(k, {}), v)
d[k] = deep_update_old(d.get(k, {}), v)
else:
d[k] = v
return d
def deep_update(d, u):
for kstr, v in u.items():
ks = kstr.split('.')
head = d
for k in ks:
last_head = head
head = head[k]
if isinstance(v, collections.abc.Mapping):
last_head[ks[-1]] = deep_update(d.get(k, {}), v)
else:
last_head[ks[-1]] = v
return d
def expand_vars(string, **kwargs):
if isinstance(string, str):
return string.format(**kwargs)
@ -103,25 +118,41 @@ def consume(conf, keys, default=None):
if k in conf:
del conf[k]
return val
child = conf[keys_arr[0]]
child = conf.get(keys_arr[0], {})
child_keys = '.'.join(keys_arr[1:])
return consume(child, child_keys, default=default)
def run_local(filename, name, job_num=None):
config = load_config(filename, name)
if 'sweep' in config and config['sweep']['enable']:
sweepC = config['sweep']
del sweepC['enable']
if consume(config, 'sweep.enable', False):
sweepC = consume(config, 'sweep')
project = consume(config, 'wandb.project')
sweep_id = wandb.sweep(
sweep=sweepC,
project=config['project']
project=project
)
wandb.agent(sweep_id, function=run_single, count=config['reps_per_agent'])
runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {})
wandb.agent(sweep_id, function=partial(run_from_sweep, config, runnerName, project, wandbC), count=config['scheduler']['reps_per_agent'])
else:
consume(config, 'sweep', {})
run_single(config)
def run_from_sweep(orig_config, runnerName, project, wandbC):
runner = Runners[runnerName]
with wandb.init(
project=project,
**wandbC
) as run:
config = copy.deepcopy(orig_config)
deep_update(config, wandb.config)
runner(run, config)
assert config == {}, ('Config was not completely consumed: ', config)
def run_slurm(filename, name):
assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
config = load_config(filename, name)
@ -148,10 +179,7 @@ def run_slurm(filename, name):
def run_single(config):
runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {})
try:
runner = Runners[runnerName]
except:
d()
runner = Runners[runnerName]
with wandb.init(
project=consume(wandbC, 'project'),