diff --git a/slate/slate.py b/slate/slate.py
index 938f030..6234be0 100644
--- a/slate/slate.py
+++ b/slate/slate.py
@@ -4,7 +4,9 @@ import os
 import math
 import random
 import copy
-import collections.abc
+import re
+import itertools
+from collections.abc import Mapping, MutableMapping, MutableSequence
 from functools import partial
 from multiprocessing import Process
 from threading import Thread
@@ -92,18 +94,18 @@ class Slate():
                 if k not in head:
                     head[k] = {}
                 head = head[k]
-            if isinstance(v, collections.abc.Mapping):
+            if isinstance(v, Mapping):
                 last_head[ks[-1]] = self.deep_update(d.get(k, {}), v, traverse_dot_notation=traverse_dot_notation)
             else:
                 last_head[ks[-1]] = v
         return d
 
-    def expand_vars(self, string, **kwargs):
+    def expand_vars(self, string, delta_desc='BASE', **kwargs):
         if isinstance(string, str):
             rand = int(random.random()*99999999)
             if string == '{rand}':
                 return rand
-            return string.format(**kwargs, rand=rand)
+            return string.format(delta_desc=delta_desc, **kwargs, rand=rand)
         return string
 
     def apply_nested(self, d, f):
@@ -122,11 +124,20 @@ class Slate():
         self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))
 
     def consume(self, conf, key, default=None, expand=False, **kwargs):
+        # An empty key addresses `conf` itself instead of one of its entries.
+        if key == '':
+            if expand:
+                self.deep_expand_vars(conf, config=self._config, **kwargs)
+            elif isinstance(conf, str):
+                # Re-expand until no placeholders remain (nested templates).
+                while '{' in conf:
+                    conf = self.expand_vars(conf, config=self._config, **kwargs)
+            return conf
         keys_arr = key.split('.')
         if len(keys_arr) == 1:
             k = keys_arr[0]
             if default != None:
-                if isinstance(conf, collections.abc.Mapping):
+                if isinstance(conf, Mapping):
                     val = conf.get(k, default)
                 else:
                     if default != None:
@@ -220,7 +231,7 @@ class Slate():
             f.write(f'{filename}:{name} submitted to slurm with ids {job_id}_0-{num_jobs} on [git:{self.get_version()}] at {datetime.datetime.now()}\n')
 
     def _fork_processes(self, config, rep_ids):
-        schedC = self.consume(config, 'scheduler')
+        schedC = self.consume(config, 'scheduler', {})
         agents_per_job = self.consume(schedC, 'agents_per_job', 1)
         reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
 
@@ -271,31 +282,33 @@ class Slate():
     def _run_single(self, orig_config, rep_ids, p_ind):
         print(f'[P{p_ind}] I will work on reps {rep_ids}')
-        runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
-        project = self.consume(wandbC, 'project')
+        runnerName = self.consume(orig_config, 'runner')
+        project = self.consume(orig_config, 'wandb.project', orig_config.get('project', orig_config.get('name')))
 
         Runner = self.runners[runnerName]
 
         for r in rep_ids:
             config = copy.deepcopy(orig_config)
+            runnerConf = self._make_config_for_run(config, r)
+            wandbC = self.consume(runnerConf, 'wandb', {}, expand=True, delta_desc=runnerConf.pop('delta_desc', 'BASE'))
             with wandb.init(
                 project=project,
-                config=copy.deepcopy(config),
+                config=copy.deepcopy(runnerConf),
                 reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
                 settings=wandb.Settings(**self.consume(wandbC, 'settings', {})),
                 **wandbC
             ) as run:
-                runner = Runner(self, config)
+                runner = Runner(self, runnerConf)
                 runner.setup()
                 runner.run(run)
 
-            if config != {}:
-                msg = ('Config was not completely consumed: ', config)
+            if runnerConf != {}:
+                msg = ('Config was not completely consumed: ', runnerConf)
                 if REQUIRE_CONFIG_CONSUMED:
                     raise Exception(msg)
                 else:
                     print(msg)
 
-        orig_config = config
+        orig_config = {}
 
     def _run_from_sweep(self, orig_config, p_ind):
         runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
@@ -322,7 +335,39 @@ class Slate():
                     raise Exception(msg)
                 else:
                     print(msg)
-        orig_config = config
+        orig_config = {}
+
+    def _make_config_for_run(self, config, r):
+        """Return the config variant that repetition `r` should run.
+
+        Variants are the grid versions followed by the ablative versions;
+        `r` cycles through them via modulo.
+        """
+        c = copy.deepcopy(config)
+
+        grid_versions = self._make_grid_versions(c)
+        all_versions = self._make_ablative_versions(c, grid_versions)
+        if not all_versions:
+            raise ValueError('grid/ablative expansion produced no config versions')
+
+        i = r % len(all_versions)
+        print(f'[d] Running version {i}/{len(all_versions)} in run {r}')
+        cur_conf = all_versions[i]
+        # The ablation spec only drives the expansion above; drop it here.
+        cur_conf.pop('ablative', None)
+        return cur_conf
+
+    def _make_grid_versions(self, config):
+        """Expand the `grid` section into one config per value combination."""
+        if 'grid' in config:
+            return params_combine(config, 'grid', itertools.product)
+        return [config]
+
+    def _make_ablative_versions(self, config, grid_versions):
+        """Append the one-parameter `ablative` variants to the grid versions."""
+        if 'ablative' in config:
+            return grid_versions + ablative_expand(grid_versions)
+        return grid_versions
 
     def from_args(self):
         import argparse
@@ -350,6 +395,147 @@ class Slate():
             self.run_local(args.config_file, args.experiment, args.job_id, args.sweep_id)
 
 
+def params_combine(config: dict, key: str, iter_func):
+    """Expand the parameter lists under `config[key]` into separate configs.
+
+    `iter_func` (e.g. `itertools.product`) receives the value lists and yields
+    one tuple of values per resulting config; `None` disables the expansion.
+    Each returned config is a deep copy without `key`, with the chosen values
+    written to their nested positions and `delta_desc` extended accordingly.
+    """
+    if iter_func is None:
+        return [config]
+
+    # Flatten the nested spec: each key is the tuple path of a parameter,
+    # each value the list of settings to try for it.
+    tuple_dict = flatten_dict_to_tuple_keys(config[key])
+    _param_names = ['.'.join(t) for t in tuple_dict]
+
+    combined_configs = []
+    # create a new config for each parameter setting
+    for values in iter_func(*tuple_dict.values()):
+        _config = copy.deepcopy(config)
+
+        # Remove Grid/List Argument
+        del _config[key]
+
+        # Expand Grid/List Parameters
+        for t, value in zip(tuple_dict, values):
+            insert_deep_dictionary(d=_config, t=t, value=value)
+
+        _config = extend_config_name(_config, _param_names, values)
+        combined_configs.append(_config)
+    return combined_configs
+
+
+def ablative_expand(conf_list):
+    """Create one-parameter-at-a-time variants of every config in `conf_list`.
+
+    For each config, every value listed under its `ablative` section yields a
+    deep copy with just that one parameter overridden and `delta_desc` extended.
+    """
+    combined_configs = []
+    for config in conf_list:
+        tuple_dict = flatten_dict_to_tuple_keys(config['ablative'])
+
+        for key, vals in tuple_dict.items():
+            param_name = '.'.join(key)
+            for val in vals:
+                _config = copy.deepcopy(config)
+                insert_deep_dictionary(_config, key, val)
+                _config = extend_config_name(_config, [param_name], [val])
+                combined_configs.append(_config)
+    return combined_configs
+
+
+def flatten_dict_to_tuple_keys(d: MutableMapping):
+    """Flatten a nested mapping into `{path_tuple: list_of_values}`.
+
+    Only list-like leaves are kept; scalar leaves are skipped.
+    """
+    flat_dict = {}
+    for k, v in d.items():
+        if isinstance(v, MutableMapping):
+            sub_dict = flatten_dict_to_tuple_keys(v)
+            flat_dict.update({(k, *sk): sv for sk, sv in sub_dict.items()})
+
+        elif isinstance(v, MutableSequence):
+            flat_dict[(k,)] = v
+
+    return flat_dict
+
+
+def insert_deep_dictionary(d: MutableMapping, t: tuple, value):
+    """Set `value` in `d` at the nested path `t`, creating dicts as needed.
+
+    A non-tuple `t` is used directly as a single key.
+    """
+    if isinstance(t, tuple):
+        if len(t) == 1:  # tuple contains only one key
+            d[t[0]] = value
+        else:  # descend, creating the intermediate dict if it is missing
+            insert_deep_dictionary(d.setdefault(t[0], {}), t[1:], value)
+    else:
+        d[t] = value
+
+
+def append_deep_dictionary(d: MutableMapping, t: tuple, value):
+    """Append `value` to the list in `d` at the nested path `t`.
+
+    Missing intermediate dicts and the final list are created as needed.
+    A non-tuple `t` is used directly as a key and is assigned, not appended.
+    """
+    if isinstance(t, tuple):
+        if len(t) == 1:  # tuple contains only one key
+            d.setdefault(t[0], []).append(value)
+        else:  # descend, creating the intermediate dict if it is missing
+            append_deep_dictionary(d.setdefault(t[0], {}), t[1:], value)
+    else:
+        d[t] = value
+
+
+def extend_config_name(config: dict, param_names: list, values: list) -> dict:
+    """Append a description of the given settings to `config['delta_desc']`."""
+    _converted_name = convert_param_names(param_names, values)
+
+    if 'delta_desc' in config:
+        config['delta_desc'] += '_' + _converted_name
+    else:
+        config['delta_desc'] = _converted_name
+    return config
+
+
+def convert_param_names(_param_names: list, values: list) -> str:
+    """Build a compact name fragment such as `opt.lr0.1_bs32` from settings.
+
+    Quotes, spaces and closing brackets are dropped; opening brackets and
+    commas become underscores.
+    """
+    _converted_name = '_'.join(
+        f'{shorten_param(k)}{v}' for k, v in zip(_param_names, values))
+    _converted_name = re.sub(r"[' ]", '', _converted_name)
+    _converted_name = re.sub(r'["]', '', _converted_name)
+    _converted_name = re.sub(r"[(\[]", '_', _converted_name)
+    _converted_name = re.sub(r"[)\]]", '', _converted_name)
+    _converted_name = re.sub(r"[,]", '_', _converted_name)
+    return _converted_name
+
+
+def shorten_param(_param_name):
+    """Abbreviate a dotted path, e.g. `optimizer.learning_rate` -> `opt.lr`.
+
+    Parent segments keep their first three characters; the leaf keeps the
+    first character of each underscore-separated word.
+    """
+    name_parts = _param_name.split('.')
+    shortened_parts = '.'.join(s[:3] for s in name_parts[:-1])
+    # s[:1] rather than s[0]: empty words (e.g. from `a__b`) must not raise.
+    shortened_leaf = ''.join(s[:1] for s in name_parts[-1].split('_'))
+    if shortened_parts:
+        return shortened_parts + '.' + shortened_leaf
+    return shortened_leaf
+
+
 class Slate_Runner():
     def __init__(self, slate, config):
         self.slate = slate
@@ -366,10 +552,9 @@ class Print_Config_Runner(Slate_Runner):
 
     def run(self, run):
         slate, config = self.slate, self.config
-        ptr = {'ptr': config}
         pprint(config)
         print('---')
-        pprint(slate.consume(ptr, 'ptr', expand=True))
+        pprint(slate.consume(config, '', expand=True))
         for k in list(config.keys()):
             del config[k]
 