[WIP] Implemented Grid and ablative
(functional but ugly)
This commit is contained in:
parent
74b06d92e7
commit
2248b24bcb
181
slate/slate.py
181
slate/slate.py
@ -4,7 +4,9 @@ import os
|
||||
import math
|
||||
import random
|
||||
import copy
|
||||
import collections.abc
|
||||
import re
|
||||
import itertools
|
||||
from collections.abc import *
|
||||
from functools import partial
|
||||
from multiprocessing import Process
|
||||
from threading import Thread
|
||||
@ -92,18 +94,18 @@ class Slate():
|
||||
if k not in head:
|
||||
head[k] = {}
|
||||
head = head[k]
|
||||
if isinstance(v, collections.abc.Mapping):
|
||||
if isinstance(v, Mapping):
|
||||
last_head[ks[-1]] = self.deep_update(d.get(k, {}), v, traverse_dot_notation=traverse_dot_notation)
|
||||
else:
|
||||
last_head[ks[-1]] = v
|
||||
return d
|
||||
|
||||
def expand_vars(self, string, **kwargs):
|
||||
def expand_vars(self, string, delta_desc='BASE', **kwargs):
|
||||
if isinstance(string, str):
|
||||
rand = int(random.random()*99999999)
|
||||
if string == '{rand}':
|
||||
return rand
|
||||
return string.format(**kwargs, rand=rand)
|
||||
return string.format(delta_desc=delta_desc, **kwargs, rand=rand)
|
||||
return string
|
||||
|
||||
def apply_nested(self, d, f):
|
||||
@ -122,11 +124,18 @@ class Slate():
|
||||
self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))
|
||||
|
||||
def consume(self, conf, key, default=None, expand=False, **kwargs):
|
||||
if key == '':
|
||||
if expand:
|
||||
self.deep_expand_vars(conf, config=self._config, **kwargs)
|
||||
elif type(conf) == str:
|
||||
while conf.find('{') != -1:
|
||||
conf = self.expand_vars(conf, config=self._config, **kwargs)
|
||||
return conf
|
||||
keys_arr = key.split('.')
|
||||
if len(keys_arr) == 1:
|
||||
k = keys_arr[0]
|
||||
if default != None:
|
||||
if isinstance(conf, collections.abc.Mapping):
|
||||
if isinstance(conf, Mapping):
|
||||
val = conf.get(k, default)
|
||||
else:
|
||||
if default != None:
|
||||
@ -220,7 +229,7 @@ class Slate():
|
||||
f.write(f'{filename}:{name} submitted to slurm with ids {job_id}_0-{num_jobs} on [git:{self.get_version()}] at {datetime.datetime.now()}\n')
|
||||
|
||||
def _fork_processes(self, config, rep_ids):
|
||||
schedC = self.consume(config, 'scheduler')
|
||||
schedC = self.consume(config, 'scheduler', {})
|
||||
agents_per_job = self.consume(schedC, 'agents_per_job', 1)
|
||||
reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
|
||||
|
||||
@ -271,31 +280,33 @@ class Slate():
|
||||
|
||||
def _run_single(self, orig_config, rep_ids, p_ind):
|
||||
print(f'[P{p_ind}] I will work on reps {rep_ids}')
|
||||
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
|
||||
project = self.consume(wandbC, 'project')
|
||||
runnerName = self.consume(orig_config, 'runner')
|
||||
project = self.consume(orig_config, 'wandb.project', orig_config.get('project', orig_config.get('name')))
|
||||
|
||||
Runner = self.runners[runnerName]
|
||||
|
||||
for r in rep_ids:
|
||||
config = copy.deepcopy(orig_config)
|
||||
runnerConf = self._make_config_for_run(config, r)
|
||||
wandbC = self.consume(runnerConf, 'wandb', {}, expand=True, delta_desc=runnerConf.pop('delta_desc', 'BASE'))
|
||||
with wandb.init(
|
||||
project=project,
|
||||
config=copy.deepcopy(config),
|
||||
config=copy.deepcopy(runnerConf),
|
||||
reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
|
||||
settings=wandb.Settings(**self.consume(wandbC, 'settings', {})),
|
||||
**wandbC
|
||||
) as run:
|
||||
runner = Runner(self, config)
|
||||
runner = Runner(self, runnerConf)
|
||||
runner.setup()
|
||||
runner.run(run)
|
||||
|
||||
if config != {}:
|
||||
msg = ('Config was not completely consumed: ', config)
|
||||
if runnerConf != {}:
|
||||
msg = ('Config was not completely consumed: ', runnerConf)
|
||||
if REQUIRE_CONFIG_CONSUMED:
|
||||
raise Exception(msg)
|
||||
else:
|
||||
print(msg)
|
||||
orig_config = config
|
||||
orig_config = {}
|
||||
|
||||
def _run_from_sweep(self, orig_config, p_ind):
|
||||
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
|
||||
@ -322,7 +333,31 @@ class Slate():
|
||||
raise Exception(msg)
|
||||
else:
|
||||
print(msg)
|
||||
orig_config = config
|
||||
orig_config = {}
|
||||
|
||||
def _make_config_for_run(self, config, r):
|
||||
c = copy.deepcopy(config)
|
||||
|
||||
grid_versions = self._make_grid_versions(c)
|
||||
all_versions = self._make_ablative_versions(c, grid_versions)
|
||||
|
||||
i = r % len(all_versions)
|
||||
print(f'[d] Running version {i}/{len(all_versions)} in run {r}')
|
||||
cur_conf = all_versions[i]
|
||||
if 'ablative' in cur_conf:
|
||||
del cur_conf['ablative']
|
||||
return cur_conf
|
||||
|
||||
def _make_grid_versions(self, config):
|
||||
if 'grid' in config:
|
||||
return params_combine(config, 'grid', itertools.product)
|
||||
return [config]
|
||||
|
||||
def _make_ablative_versions(self, config, grid_versions):
|
||||
if 'ablative' in config:
|
||||
return grid_versions + ablative_expand(grid_versions)
|
||||
else:
|
||||
return grid_versions
|
||||
|
||||
def from_args(self):
|
||||
import argparse
|
||||
@ -350,6 +385,121 @@ class Slate():
|
||||
self.run_local(args.config_file, args.experiment, args.job_id, args.sweep_id)
|
||||
|
||||
|
||||
def params_combine(config: dict, key: str, iter_func):
|
||||
if iter_func is None:
|
||||
return [config]
|
||||
|
||||
combined_configs = []
|
||||
# convert list/grid dictionary into flat dictionary, where the key is a tuple of the keys and the
|
||||
# value is the list of values
|
||||
tuple_dict = flatten_dict_to_tuple_keys(config[key])
|
||||
_param_names = ['.'.join(t) for t in tuple_dict]
|
||||
|
||||
param_lengths = map(len, tuple_dict.values())
|
||||
|
||||
# create a new config for each parameter setting
|
||||
for values in iter_func(*tuple_dict.values()):
|
||||
_config = copy.deepcopy(config)
|
||||
|
||||
# Remove Grid/List Argument
|
||||
del _config[key]
|
||||
|
||||
# Expand Grid/List Parameters
|
||||
for i, t in enumerate(tuple_dict.keys()):
|
||||
insert_deep_dictionary(d=_config, t=t, value=values[i])
|
||||
|
||||
_config = extend_config_name(_config, _param_names, values)
|
||||
combined_configs.append(_config)
|
||||
return combined_configs
|
||||
|
||||
|
||||
def ablative_expand(conf_list):
|
||||
combined_configs = []
|
||||
for config in conf_list:
|
||||
tuple_dict = flatten_dict_to_tuple_keys(config['ablative'])
|
||||
_param_names = ['.'.join(t) for t in tuple_dict]
|
||||
|
||||
for i, key in enumerate(tuple_dict):
|
||||
for val in tuple_dict[key]:
|
||||
_config = copy.deepcopy(config)
|
||||
|
||||
insert_deep_dictionary(
|
||||
_config, key, val
|
||||
)
|
||||
|
||||
_config = extend_config_name(_config, [_param_names[i]], [val])
|
||||
combined_configs.append(_config)
|
||||
return combined_configs
|
||||
|
||||
|
||||
def flatten_dict_to_tuple_keys(d: MutableMapping):
|
||||
flat_dict = {}
|
||||
for k, v in d.items():
|
||||
if isinstance(v, MutableMapping):
|
||||
sub_dict = flatten_dict_to_tuple_keys(v)
|
||||
flat_dict.update({(k, *sk): sv for sk, sv in sub_dict.items()})
|
||||
|
||||
elif isinstance(v, MutableSequence):
|
||||
flat_dict[(k,)] = v
|
||||
|
||||
return flat_dict
|
||||
|
||||
|
||||
def insert_deep_dictionary(d: MutableMapping, t: tuple, value):
|
||||
if type(t) is tuple:
|
||||
if len(t) == 1: # tuple contains only one key
|
||||
d[t[0]] = value
|
||||
else: # tuple contains more than one key
|
||||
if t[0] not in d:
|
||||
d[t[0]] = dict()
|
||||
insert_deep_dictionary(d[t[0]], t[1:], value)
|
||||
else:
|
||||
d[t] = value
|
||||
|
||||
|
||||
def append_deep_dictionary(d: MutableMapping, t: tuple, value):
|
||||
if type(t) is tuple:
|
||||
if len(t) == 1: # tuple contains only one key
|
||||
if t[0] not in d:
|
||||
d[t[0]] = []
|
||||
d[t[0]].append(value)
|
||||
else: # tuple contains more than one key
|
||||
if t[0] not in d:
|
||||
d[t[0]] = dict()
|
||||
append_deep_dictionary(d[t[0]], t[1:], value)
|
||||
else:
|
||||
d[t] = value
|
||||
|
||||
|
||||
def extend_config_name(config: dict, param_names: list, values: list) -> dict:
|
||||
_converted_name = convert_param_names(param_names, values)
|
||||
|
||||
config['delta_desc'] = config['delta_desc'] + '_' + _converted_name if 'delta_desc' in config else _converted_name
|
||||
return config
|
||||
|
||||
|
||||
def convert_param_names(_param_names: list, values: list) -> str:
|
||||
_converted_name = '_'.join("{}{}".format(
|
||||
shorten_param(k), v) for k, v in zip(_param_names, values))
|
||||
# _converted_name = re.sub("[' \[\],()]", '', _converted_name)
|
||||
_converted_name = re.sub("[' ]", '', _converted_name)
|
||||
_converted_name = re.sub('["]', '', _converted_name)
|
||||
_converted_name = re.sub("[(\[]", '_', _converted_name)
|
||||
_converted_name = re.sub("[)\]]", '', _converted_name)
|
||||
_converted_name = re.sub("[,]", '_', _converted_name)
|
||||
return _converted_name
|
||||
|
||||
|
||||
def shorten_param(_param_name):
|
||||
name_parts = _param_name.split('.')
|
||||
shortened_parts = '.'.join(map(lambda s: s[:3], name_parts[:-1]))
|
||||
shortened_leaf = ''.join(map(lambda s: s[0], name_parts[-1].split('_')))
|
||||
if shortened_parts:
|
||||
return shortened_parts + '.' + shortened_leaf
|
||||
else:
|
||||
return shortened_leaf
|
||||
|
||||
|
||||
class Slate_Runner():
|
||||
def __init__(self, slate, config):
|
||||
self.slate = slate
|
||||
@ -366,10 +516,9 @@ class Print_Config_Runner(Slate_Runner):
|
||||
def run(self, run):
|
||||
slate, config = self.slate, self.config
|
||||
|
||||
ptr = {'ptr': config}
|
||||
pprint(config)
|
||||
print('---')
|
||||
pprint(slate.consume(ptr, 'ptr', expand=True))
|
||||
pprint(slate.consume(config, '', expand=True))
|
||||
for k in list(config.keys()):
|
||||
del config[k]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user