[WIP] Implemented Grid and ablative
(functional but ugly)
This commit is contained in:
parent
74b06d92e7
commit
2248b24bcb
181
slate/slate.py
181
slate/slate.py
@ -4,7 +4,9 @@ import os
|
|||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
import copy
|
import copy
|
||||||
import collections.abc
|
import re
|
||||||
|
import itertools
|
||||||
|
from collections.abc import *
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from multiprocessing import Process
|
from multiprocessing import Process
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
@ -92,18 +94,18 @@ class Slate():
|
|||||||
if k not in head:
|
if k not in head:
|
||||||
head[k] = {}
|
head[k] = {}
|
||||||
head = head[k]
|
head = head[k]
|
||||||
if isinstance(v, collections.abc.Mapping):
|
if isinstance(v, Mapping):
|
||||||
last_head[ks[-1]] = self.deep_update(d.get(k, {}), v, traverse_dot_notation=traverse_dot_notation)
|
last_head[ks[-1]] = self.deep_update(d.get(k, {}), v, traverse_dot_notation=traverse_dot_notation)
|
||||||
else:
|
else:
|
||||||
last_head[ks[-1]] = v
|
last_head[ks[-1]] = v
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def expand_vars(self, string, **kwargs):
|
def expand_vars(self, string, delta_desc='BASE', **kwargs):
|
||||||
if isinstance(string, str):
|
if isinstance(string, str):
|
||||||
rand = int(random.random()*99999999)
|
rand = int(random.random()*99999999)
|
||||||
if string == '{rand}':
|
if string == '{rand}':
|
||||||
return rand
|
return rand
|
||||||
return string.format(**kwargs, rand=rand)
|
return string.format(delta_desc=delta_desc, **kwargs, rand=rand)
|
||||||
return string
|
return string
|
||||||
|
|
||||||
def apply_nested(self, d, f):
|
def apply_nested(self, d, f):
|
||||||
@ -122,11 +124,18 @@ class Slate():
|
|||||||
self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))
|
self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))
|
||||||
|
|
||||||
def consume(self, conf, key, default=None, expand=False, **kwargs):
|
def consume(self, conf, key, default=None, expand=False, **kwargs):
|
||||||
|
if key == '':
|
||||||
|
if expand:
|
||||||
|
self.deep_expand_vars(conf, config=self._config, **kwargs)
|
||||||
|
elif type(conf) == str:
|
||||||
|
while conf.find('{') != -1:
|
||||||
|
conf = self.expand_vars(conf, config=self._config, **kwargs)
|
||||||
|
return conf
|
||||||
keys_arr = key.split('.')
|
keys_arr = key.split('.')
|
||||||
if len(keys_arr) == 1:
|
if len(keys_arr) == 1:
|
||||||
k = keys_arr[0]
|
k = keys_arr[0]
|
||||||
if default != None:
|
if default != None:
|
||||||
if isinstance(conf, collections.abc.Mapping):
|
if isinstance(conf, Mapping):
|
||||||
val = conf.get(k, default)
|
val = conf.get(k, default)
|
||||||
else:
|
else:
|
||||||
if default != None:
|
if default != None:
|
||||||
@ -220,7 +229,7 @@ class Slate():
|
|||||||
f.write(f'{filename}:{name} submitted to slurm with ids {job_id}_0-{num_jobs} on [git:{self.get_version()}] at {datetime.datetime.now()}\n')
|
f.write(f'{filename}:{name} submitted to slurm with ids {job_id}_0-{num_jobs} on [git:{self.get_version()}] at {datetime.datetime.now()}\n')
|
||||||
|
|
||||||
def _fork_processes(self, config, rep_ids):
|
def _fork_processes(self, config, rep_ids):
|
||||||
schedC = self.consume(config, 'scheduler')
|
schedC = self.consume(config, 'scheduler', {})
|
||||||
agents_per_job = self.consume(schedC, 'agents_per_job', 1)
|
agents_per_job = self.consume(schedC, 'agents_per_job', 1)
|
||||||
reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
|
reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
|
||||||
|
|
||||||
@ -271,31 +280,33 @@ class Slate():
|
|||||||
|
|
||||||
def _run_single(self, orig_config, rep_ids, p_ind):
|
def _run_single(self, orig_config, rep_ids, p_ind):
|
||||||
print(f'[P{p_ind}] I will work on reps {rep_ids}')
|
print(f'[P{p_ind}] I will work on reps {rep_ids}')
|
||||||
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
|
runnerName = self.consume(orig_config, 'runner')
|
||||||
project = self.consume(wandbC, 'project')
|
project = self.consume(orig_config, 'wandb.project', orig_config.get('project', orig_config.get('name')))
|
||||||
|
|
||||||
Runner = self.runners[runnerName]
|
Runner = self.runners[runnerName]
|
||||||
|
|
||||||
for r in rep_ids:
|
for r in rep_ids:
|
||||||
config = copy.deepcopy(orig_config)
|
config = copy.deepcopy(orig_config)
|
||||||
|
runnerConf = self._make_config_for_run(config, r)
|
||||||
|
wandbC = self.consume(runnerConf, 'wandb', {}, expand=True, delta_desc=runnerConf.pop('delta_desc', 'BASE'))
|
||||||
with wandb.init(
|
with wandb.init(
|
||||||
project=project,
|
project=project,
|
||||||
config=copy.deepcopy(config),
|
config=copy.deepcopy(runnerConf),
|
||||||
reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
|
reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
|
||||||
settings=wandb.Settings(**self.consume(wandbC, 'settings', {})),
|
settings=wandb.Settings(**self.consume(wandbC, 'settings', {})),
|
||||||
**wandbC
|
**wandbC
|
||||||
) as run:
|
) as run:
|
||||||
runner = Runner(self, config)
|
runner = Runner(self, runnerConf)
|
||||||
runner.setup()
|
runner.setup()
|
||||||
runner.run(run)
|
runner.run(run)
|
||||||
|
|
||||||
if config != {}:
|
if runnerConf != {}:
|
||||||
msg = ('Config was not completely consumed: ', config)
|
msg = ('Config was not completely consumed: ', runnerConf)
|
||||||
if REQUIRE_CONFIG_CONSUMED:
|
if REQUIRE_CONFIG_CONSUMED:
|
||||||
raise Exception(msg)
|
raise Exception(msg)
|
||||||
else:
|
else:
|
||||||
print(msg)
|
print(msg)
|
||||||
orig_config = config
|
orig_config = {}
|
||||||
|
|
||||||
def _run_from_sweep(self, orig_config, p_ind):
|
def _run_from_sweep(self, orig_config, p_ind):
|
||||||
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
|
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
|
||||||
@ -322,7 +333,31 @@ class Slate():
|
|||||||
raise Exception(msg)
|
raise Exception(msg)
|
||||||
else:
|
else:
|
||||||
print(msg)
|
print(msg)
|
||||||
orig_config = config
|
orig_config = {}
|
||||||
|
|
||||||
|
def _make_config_for_run(self, config, r):
|
||||||
|
c = copy.deepcopy(config)
|
||||||
|
|
||||||
|
grid_versions = self._make_grid_versions(c)
|
||||||
|
all_versions = self._make_ablative_versions(c, grid_versions)
|
||||||
|
|
||||||
|
i = r % len(all_versions)
|
||||||
|
print(f'[d] Running version {i}/{len(all_versions)} in run {r}')
|
||||||
|
cur_conf = all_versions[i]
|
||||||
|
if 'ablative' in cur_conf:
|
||||||
|
del cur_conf['ablative']
|
||||||
|
return cur_conf
|
||||||
|
|
||||||
|
def _make_grid_versions(self, config):
|
||||||
|
if 'grid' in config:
|
||||||
|
return params_combine(config, 'grid', itertools.product)
|
||||||
|
return [config]
|
||||||
|
|
||||||
|
def _make_ablative_versions(self, config, grid_versions):
|
||||||
|
if 'ablative' in config:
|
||||||
|
return grid_versions + ablative_expand(grid_versions)
|
||||||
|
else:
|
||||||
|
return grid_versions
|
||||||
|
|
||||||
def from_args(self):
|
def from_args(self):
|
||||||
import argparse
|
import argparse
|
||||||
@ -350,6 +385,121 @@ class Slate():
|
|||||||
self.run_local(args.config_file, args.experiment, args.job_id, args.sweep_id)
|
self.run_local(args.config_file, args.experiment, args.job_id, args.sweep_id)
|
||||||
|
|
||||||
|
|
||||||
|
def params_combine(config: dict, key: str, iter_func):
|
||||||
|
if iter_func is None:
|
||||||
|
return [config]
|
||||||
|
|
||||||
|
combined_configs = []
|
||||||
|
# convert list/grid dictionary into flat dictionary, where the key is a tuple of the keys and the
|
||||||
|
# value is the list of values
|
||||||
|
tuple_dict = flatten_dict_to_tuple_keys(config[key])
|
||||||
|
_param_names = ['.'.join(t) for t in tuple_dict]
|
||||||
|
|
||||||
|
param_lengths = map(len, tuple_dict.values())
|
||||||
|
|
||||||
|
# create a new config for each parameter setting
|
||||||
|
for values in iter_func(*tuple_dict.values()):
|
||||||
|
_config = copy.deepcopy(config)
|
||||||
|
|
||||||
|
# Remove Grid/List Argument
|
||||||
|
del _config[key]
|
||||||
|
|
||||||
|
# Expand Grid/List Parameters
|
||||||
|
for i, t in enumerate(tuple_dict.keys()):
|
||||||
|
insert_deep_dictionary(d=_config, t=t, value=values[i])
|
||||||
|
|
||||||
|
_config = extend_config_name(_config, _param_names, values)
|
||||||
|
combined_configs.append(_config)
|
||||||
|
return combined_configs
|
||||||
|
|
||||||
|
|
||||||
|
def ablative_expand(conf_list):
|
||||||
|
combined_configs = []
|
||||||
|
for config in conf_list:
|
||||||
|
tuple_dict = flatten_dict_to_tuple_keys(config['ablative'])
|
||||||
|
_param_names = ['.'.join(t) for t in tuple_dict]
|
||||||
|
|
||||||
|
for i, key in enumerate(tuple_dict):
|
||||||
|
for val in tuple_dict[key]:
|
||||||
|
_config = copy.deepcopy(config)
|
||||||
|
|
||||||
|
insert_deep_dictionary(
|
||||||
|
_config, key, val
|
||||||
|
)
|
||||||
|
|
||||||
|
_config = extend_config_name(_config, [_param_names[i]], [val])
|
||||||
|
combined_configs.append(_config)
|
||||||
|
return combined_configs
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_dict_to_tuple_keys(d: MutableMapping):
|
||||||
|
flat_dict = {}
|
||||||
|
for k, v in d.items():
|
||||||
|
if isinstance(v, MutableMapping):
|
||||||
|
sub_dict = flatten_dict_to_tuple_keys(v)
|
||||||
|
flat_dict.update({(k, *sk): sv for sk, sv in sub_dict.items()})
|
||||||
|
|
||||||
|
elif isinstance(v, MutableSequence):
|
||||||
|
flat_dict[(k,)] = v
|
||||||
|
|
||||||
|
return flat_dict
|
||||||
|
|
||||||
|
|
||||||
|
def insert_deep_dictionary(d: MutableMapping, t: tuple, value):
|
||||||
|
if type(t) is tuple:
|
||||||
|
if len(t) == 1: # tuple contains only one key
|
||||||
|
d[t[0]] = value
|
||||||
|
else: # tuple contains more than one key
|
||||||
|
if t[0] not in d:
|
||||||
|
d[t[0]] = dict()
|
||||||
|
insert_deep_dictionary(d[t[0]], t[1:], value)
|
||||||
|
else:
|
||||||
|
d[t] = value
|
||||||
|
|
||||||
|
|
||||||
|
def append_deep_dictionary(d: MutableMapping, t: tuple, value):
|
||||||
|
if type(t) is tuple:
|
||||||
|
if len(t) == 1: # tuple contains only one key
|
||||||
|
if t[0] not in d:
|
||||||
|
d[t[0]] = []
|
||||||
|
d[t[0]].append(value)
|
||||||
|
else: # tuple contains more than one key
|
||||||
|
if t[0] not in d:
|
||||||
|
d[t[0]] = dict()
|
||||||
|
append_deep_dictionary(d[t[0]], t[1:], value)
|
||||||
|
else:
|
||||||
|
d[t] = value
|
||||||
|
|
||||||
|
|
||||||
|
def extend_config_name(config: dict, param_names: list, values: list) -> dict:
|
||||||
|
_converted_name = convert_param_names(param_names, values)
|
||||||
|
|
||||||
|
config['delta_desc'] = config['delta_desc'] + '_' + _converted_name if 'delta_desc' in config else _converted_name
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def convert_param_names(_param_names: list, values: list) -> str:
|
||||||
|
_converted_name = '_'.join("{}{}".format(
|
||||||
|
shorten_param(k), v) for k, v in zip(_param_names, values))
|
||||||
|
# _converted_name = re.sub("[' \[\],()]", '', _converted_name)
|
||||||
|
_converted_name = re.sub("[' ]", '', _converted_name)
|
||||||
|
_converted_name = re.sub('["]', '', _converted_name)
|
||||||
|
_converted_name = re.sub("[(\[]", '_', _converted_name)
|
||||||
|
_converted_name = re.sub("[)\]]", '', _converted_name)
|
||||||
|
_converted_name = re.sub("[,]", '_', _converted_name)
|
||||||
|
return _converted_name
|
||||||
|
|
||||||
|
|
||||||
|
def shorten_param(_param_name):
|
||||||
|
name_parts = _param_name.split('.')
|
||||||
|
shortened_parts = '.'.join(map(lambda s: s[:3], name_parts[:-1]))
|
||||||
|
shortened_leaf = ''.join(map(lambda s: s[0], name_parts[-1].split('_')))
|
||||||
|
if shortened_parts:
|
||||||
|
return shortened_parts + '.' + shortened_leaf
|
||||||
|
else:
|
||||||
|
return shortened_leaf
|
||||||
|
|
||||||
|
|
||||||
class Slate_Runner():
|
class Slate_Runner():
|
||||||
def __init__(self, slate, config):
|
def __init__(self, slate, config):
|
||||||
self.slate = slate
|
self.slate = slate
|
||||||
@ -366,10 +516,9 @@ class Print_Config_Runner(Slate_Runner):
|
|||||||
def run(self, run):
|
def run(self, run):
|
||||||
slate, config = self.slate, self.config
|
slate, config = self.slate, self.config
|
||||||
|
|
||||||
ptr = {'ptr': config}
|
|
||||||
pprint(config)
|
pprint(config)
|
||||||
print('---')
|
print('---')
|
||||||
pprint(slate.consume(ptr, 'ptr', expand=True))
|
pprint(slate.consume(config, '', expand=True))
|
||||||
for k in list(config.keys()):
|
for k in list(config.keys()):
|
||||||
del config[k]
|
del config[k]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user