[WIP] Implemented Grid and ablative

(functional but ugly)
This commit is contained in:
Dominik Moritz Roth 2023-07-29 13:03:01 +02:00
parent 74b06d92e7
commit 2248b24bcb

View File

@ -4,7 +4,9 @@ import os
import math
import random
import copy
import collections.abc
import re
import itertools
from collections.abc import *
from functools import partial
from multiprocessing import Process
from threading import Thread
@ -92,18 +94,18 @@ class Slate():
if k not in head:
head[k] = {}
head = head[k]
if isinstance(v, collections.abc.Mapping):
if isinstance(v, Mapping):
last_head[ks[-1]] = self.deep_update(d.get(k, {}), v, traverse_dot_notation=traverse_dot_notation)
else:
last_head[ks[-1]] = v
return d
def expand_vars(self, string, **kwargs):
def expand_vars(self, string, delta_desc='BASE', **kwargs):
if isinstance(string, str):
rand = int(random.random()*99999999)
if string == '{rand}':
return rand
return string.format(**kwargs, rand=rand)
return string.format(delta_desc=delta_desc, **kwargs, rand=rand)
return string
def apply_nested(self, d, f):
@ -122,11 +124,18 @@ class Slate():
self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))
def consume(self, conf, key, default=None, expand=False, **kwargs):
if key == '':
if expand:
self.deep_expand_vars(conf, config=self._config, **kwargs)
elif type(conf) == str:
while conf.find('{') != -1:
conf = self.expand_vars(conf, config=self._config, **kwargs)
return conf
keys_arr = key.split('.')
if len(keys_arr) == 1:
k = keys_arr[0]
if default != None:
if isinstance(conf, collections.abc.Mapping):
if isinstance(conf, Mapping):
val = conf.get(k, default)
else:
if default != None:
@ -220,7 +229,7 @@ class Slate():
f.write(f'{filename}:{name} submitted to slurm with ids {job_id}_0-{num_jobs} on [git:{self.get_version()}] at {datetime.datetime.now()}\n')
def _fork_processes(self, config, rep_ids):
schedC = self.consume(config, 'scheduler')
schedC = self.consume(config, 'scheduler', {})
agents_per_job = self.consume(schedC, 'agents_per_job', 1)
reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
@ -271,31 +280,33 @@ class Slate():
def _run_single(self, orig_config, rep_ids, p_ind):
print(f'[P{p_ind}] I will work on reps {rep_ids}')
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
project = self.consume(wandbC, 'project')
runnerName = self.consume(orig_config, 'runner')
project = self.consume(orig_config, 'wandb.project', orig_config.get('project', orig_config.get('name')))
Runner = self.runners[runnerName]
for r in rep_ids:
config = copy.deepcopy(orig_config)
runnerConf = self._make_config_for_run(config, r)
wandbC = self.consume(runnerConf, 'wandb', {}, expand=True, delta_desc=runnerConf.pop('delta_desc', 'BASE'))
with wandb.init(
project=project,
config=copy.deepcopy(config),
config=copy.deepcopy(runnerConf),
reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
settings=wandb.Settings(**self.consume(wandbC, 'settings', {})),
**wandbC
) as run:
runner = Runner(self, config)
runner = Runner(self, runnerConf)
runner.setup()
runner.run(run)
if config != {}:
msg = ('Config was not completely consumed: ', config)
if runnerConf != {}:
msg = ('Config was not completely consumed: ', runnerConf)
if REQUIRE_CONFIG_CONSUMED:
raise Exception(msg)
else:
print(msg)
orig_config = config
orig_config = {}
def _run_from_sweep(self, orig_config, p_ind):
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
@ -322,7 +333,31 @@ class Slate():
raise Exception(msg)
else:
print(msg)
orig_config = config
orig_config = {}
def _make_config_for_run(self, config, r):
c = copy.deepcopy(config)
grid_versions = self._make_grid_versions(c)
all_versions = self._make_ablative_versions(c, grid_versions)
i = r % len(all_versions)
print(f'[d] Running version {i}/{len(all_versions)} in run {r}')
cur_conf = all_versions[i]
if 'ablative' in cur_conf:
del cur_conf['ablative']
return cur_conf
def _make_grid_versions(self, config):
if 'grid' in config:
return params_combine(config, 'grid', itertools.product)
return [config]
def _make_ablative_versions(self, config, grid_versions):
if 'ablative' in config:
return grid_versions + ablative_expand(grid_versions)
else:
return grid_versions
def from_args(self):
import argparse
@ -350,6 +385,121 @@ class Slate():
self.run_local(args.config_file, args.experiment, args.job_id, args.sweep_id)
def params_combine(config: dict, key: str, iter_func):
if iter_func is None:
return [config]
combined_configs = []
# convert list/grid dictionary into flat dictionary, where the key is a tuple of the keys and the
# value is the list of values
tuple_dict = flatten_dict_to_tuple_keys(config[key])
_param_names = ['.'.join(t) for t in tuple_dict]
param_lengths = map(len, tuple_dict.values())
# create a new config for each parameter setting
for values in iter_func(*tuple_dict.values()):
_config = copy.deepcopy(config)
# Remove Grid/List Argument
del _config[key]
# Expand Grid/List Parameters
for i, t in enumerate(tuple_dict.keys()):
insert_deep_dictionary(d=_config, t=t, value=values[i])
_config = extend_config_name(_config, _param_names, values)
combined_configs.append(_config)
return combined_configs
def ablative_expand(conf_list):
combined_configs = []
for config in conf_list:
tuple_dict = flatten_dict_to_tuple_keys(config['ablative'])
_param_names = ['.'.join(t) for t in tuple_dict]
for i, key in enumerate(tuple_dict):
for val in tuple_dict[key]:
_config = copy.deepcopy(config)
insert_deep_dictionary(
_config, key, val
)
_config = extend_config_name(_config, [_param_names[i]], [val])
combined_configs.append(_config)
return combined_configs
def flatten_dict_to_tuple_keys(d: MutableMapping):
flat_dict = {}
for k, v in d.items():
if isinstance(v, MutableMapping):
sub_dict = flatten_dict_to_tuple_keys(v)
flat_dict.update({(k, *sk): sv for sk, sv in sub_dict.items()})
elif isinstance(v, MutableSequence):
flat_dict[(k,)] = v
return flat_dict
def insert_deep_dictionary(d: MutableMapping, t: tuple, value):
if type(t) is tuple:
if len(t) == 1: # tuple contains only one key
d[t[0]] = value
else: # tuple contains more than one key
if t[0] not in d:
d[t[0]] = dict()
insert_deep_dictionary(d[t[0]], t[1:], value)
else:
d[t] = value
def append_deep_dictionary(d: MutableMapping, t: tuple, value):
if type(t) is tuple:
if len(t) == 1: # tuple contains only one key
if t[0] not in d:
d[t[0]] = []
d[t[0]].append(value)
else: # tuple contains more than one key
if t[0] not in d:
d[t[0]] = dict()
append_deep_dictionary(d[t[0]], t[1:], value)
else:
d[t] = value
def extend_config_name(config: dict, param_names: list, values: list) -> dict:
_converted_name = convert_param_names(param_names, values)
config['delta_desc'] = config['delta_desc'] + '_' + _converted_name if 'delta_desc' in config else _converted_name
return config
def convert_param_names(_param_names: list, values: list) -> str:
_converted_name = '_'.join("{}{}".format(
shorten_param(k), v) for k, v in zip(_param_names, values))
# _converted_name = re.sub("[' \[\],()]", '', _converted_name)
_converted_name = re.sub("[' ]", '', _converted_name)
_converted_name = re.sub('["]', '', _converted_name)
_converted_name = re.sub("[(\[]", '_', _converted_name)
_converted_name = re.sub("[)\]]", '', _converted_name)
_converted_name = re.sub("[,]", '_', _converted_name)
return _converted_name
def shorten_param(_param_name):
name_parts = _param_name.split('.')
shortened_parts = '.'.join(map(lambda s: s[:3], name_parts[:-1]))
shortened_leaf = ''.join(map(lambda s: s[0], name_parts[-1].split('_')))
if shortened_parts:
return shortened_parts + '.' + shortened_leaf
else:
return shortened_leaf
class Slate_Runner():
def __init__(self, slate, config):
self.slate = slate
@ -366,10 +516,9 @@ class Print_Config_Runner(Slate_Runner):
def run(self, run):
slate, config = self.slate, self.config
ptr = {'ptr': config}
pprint(config)
print('---')
pprint(slate.consume(ptr, 'ptr', expand=True))
pprint(slate.consume(config, '', expand=True))
for k in list(config.keys()):
del config[k]