"""Slate: launch experiments described in multi-document YAML files.

An experiment is selected by its ``name`` from a YAML file, may ``import``
other experiments, and is executed either locally (optionally split over
several processes/threads) or submitted as a slurm job array via pyslurm.
Every run is logged to Weights & Biases through ``wandb.init``.
"""
import wandb
import yaml
import os
import math
import time
import random
import copy
import re
import itertools
from collections.abc import *
from functools import partial
from multiprocessing import Process
from threading import Thread
import git
import datetime
from pprint import pprint

import pdb
d = pdb.set_trace

REQUIRE_CONFIG_CONSUMED = False
DEFAULT_START_METHOD = 'fork'
DEFAULT_REINIT = True

Parallelization_Primitive = Process  # Thread

try:
    import pyslurm
except ImportError:
    slurm_avaible = False
    print('[!] Slurm not avaible.')
else:
    slurm_avaible = True

# TODO: Implement Testing
# TODO: Implement Ablative


class Slate():
    """Load experiment configs and dispatch their runs to registered runner classes."""

    def __init__(self, runners):
        """
        Args:
            runners: mapping of runner name -> runner class. Merged over the
                built-in 'void', 'printConfig' and 'pdb' runners.
        """
        self.runners = {
            'void': Void_Runner,
            'printConfig': Print_Config_Runner,
            'pdb': PDB_Runner,
        }
        self.runners.update(runners)
        self._version = False
        self.job_id = os.environ.get('SLURM_JOB_ID', False)
        self.task_id = None
        self.run_id = -1
        self._tmp_path = os.path.expandvars('$TMP')
        self.sweep_id = None
        self.verify = False

    def load_config(self, filename, name):
        """Load experiment `name` from `filename`, resolving its imports.

        A pristine deep copy is kept in ``self._config`` (exposed as
        ``{config}`` during placeholder expansion). The ``vars`` section is
        dropped from the returned config.
        """
        config, stack = self._load_config(filename, name)
        print('[i] Merged Configs: ', stack)
        self._config = copy.deepcopy(config)
        self.consume(config, 'vars', {})
        return config

    def _load_config(self, filename, name, stack=None):
        """Return the YAML document named `name` in `filename`, merged with its imports.

        ``import`` is a comma-separated list of ``path:name``, ``:name`` (same
        file), ``path`` (meaning ``path:DEFAULT``) or ``$`` (meaning
        ``:DEFAULT``). Relative paths are resolved against the directory of
        `filename`. Values of the importing document override imported ones.

        Returns:
            (config, stack): the merged config and the list of every
            ``file:name`` visited while resolving imports.

        Raises:
            Exception: if no document named `name` exists or an import is malformed.
        """
        # Fresh list per top-level call; a mutable default list would keep
        # growing across calls.
        if stack is None:
            stack = []
        stack.append(f'{filename}:{name}')
        with open(filename, 'r') as f:
            for doc in yaml.safe_load_all(f):
                # Empty YAML documents load as None; skip anything that is not a mapping.
                if not isinstance(doc, Mapping) or 'name' not in doc:
                    continue
                if doc['name'] != name:
                    continue
                if 'import' in doc:
                    imports = doc['import'].split(',')
                    del doc['import']
                    for imp in imports:
                        imp = imp.strip()
                        if not imp:
                            continue  # tolerate stray / trailing commas
                        if imp == '$':
                            imp = ':DEFAULT'
                        rel_path, *opt = imp.split(':')
                        if len(opt) == 0:
                            nested_name = 'DEFAULT'
                        elif len(opt) == 1:
                            nested_name = opt[0]
                        else:
                            raise Exception('Malformed import statement. Must be <file:name>, <:name>, <file> for file:DEFAULT or <$> for :DEFAULT.')
                        if rel_path:
                            nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path))
                        else:
                            nested_path = filename
                        child, stack = self._load_config(nested_path, nested_name, stack=stack)
                        doc = self.deep_update(child, doc)
                return doc, stack
        raise Exception(f'Unable to find experiment <{name}> in <{filename}>')

    def deep_update(self, d, u, traverse_dot_notation=True):
        """Recursively merge mapping `u` into `d` (in place) and return `d`.

        Keys of `u` containing dots (``'a.b.c'``) address nested levels of `d`,
        unless `traverse_dot_notation` is False. Below a ``parameters`` key,
        dots are kept literal in the recursive merge.
        """
        for kstr, v in u.items():
            keys = kstr.split('.') if traverse_dot_notation else [kstr]
            # Decide per key; must not leak to sibling keys of `u`.
            child_traverse = traverse_dot_notation and 'parameters' not in keys
            head = d
            for k in keys[:-1]:
                if k not in head:
                    head[k] = {}
                head = head[k]
            leaf = keys[-1]
            if isinstance(v, Mapping):
                # Merge into the value at the addressed (nested) position,
                # not into a same-named key at the root of `d`.
                target = head.get(leaf)
                if not isinstance(target, MutableMapping):
                    target = {}
                head[leaf] = self.deep_update(target, v, traverse_dot_notation=child_traverse)
            else:
                head[leaf] = v
        return d

    def expand_vars(self, string, delta_desc='BASE', **kwargs):
        """Expand ``{...}`` placeholders in `string` with ``str.format``.

        Placeholders: delta_desc, rand (random int), tmp ($TMP), job_id
        ('LOCAL' outside slurm), task_id (0 if unset), run_id, plus `kwargs`.
        The exact string ``'{rand}'`` becomes the int itself. Non-strings
        are returned unchanged.
        """
        if not isinstance(string, str):
            return string
        rand = int(random.random() * 99999999)
        if string == '{rand}':
            return rand
        return string.format(delta_desc=delta_desc, **kwargs, rand=rand, tmp=self._tmp_path,
                             job_id=(self.job_id or 'LOCAL'), task_id=(self.task_id or 0),
                             run_id=self.run_id)

    def apply_nested(self, d, f):
        """Replace every leaf inside nested dicts/lists of `d` with ``f(leaf)``, in place."""
        for k, v in d.items():
            if isinstance(v, dict):
                self.apply_nested(v, f)
            elif isinstance(v, list):
                for i, e in enumerate(v):
                    # Wrap the element so nested dicts/lists reuse this walker.
                    ptr = {'PTR': e}
                    self.apply_nested(ptr, f)
                    v[i] = ptr['PTR']
            else:
                d[k] = f(v)

    def deep_expand_vars(self, dict, **kwargs):
        """Expand placeholders in every nested string of `dict`, in place."""
        self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))

    def consume(self, conf, key, default=None, expand=False, **kwargs):
        """Remove the dot-separated `key` from `conf` and return its value.

        Args:
            conf: config mapping; the consumed leaf key is deleted from it.
            key: path such as ``'wandb.project'``; ``''`` addresses `conf` itself.
            default: returned when the key is missing. ``None`` marks the key
                as required (a missing key raises KeyError).
            expand: expand placeholders in all nested strings of the value
                (in place); otherwise only a plain string value is expanded.
            **kwargs: extra placeholders forwarded to expand_vars.
        """
        if key == '':
            if expand:
                self.deep_expand_vars(conf, config=self._config, **kwargs)
            elif isinstance(conf, str):
                while '{' in conf:
                    conf = self.expand_vars(conf, config=self._config, **kwargs)
            return conf

        keys_arr = key.split('.')
        if len(keys_arr) > 1:
            # Descend one level; a missing intermediate level acts like an empty dict.
            child = conf.get(keys_arr[0], {})
            return self.consume(child, '.'.join(keys_arr[1:]), default=default, expand=expand, **kwargs)

        k = keys_arr[0]
        if default is not None:
            if not isinstance(conf, Mapping):
                return default
            val = conf.get(k, default)
        else:
            val = conf[k]
        if k in conf:
            del conf[k]

        if expand:
            self.deep_expand_vars(val, config=self._config, **kwargs)
        elif isinstance(val, str):
            while '{' in val:
                val = self.expand_vars(val, config=self._config, **kwargs)
        return val

    def get_version(self):
        """Return (and cache) the hexsha of HEAD of the enclosing git repository."""
        if not self._version:
            repo = git.Repo(search_parent_directories=True)
            self._version = repo.head.object.hexsha
        return self._version

    def _total_reps(self, schedulerC, num_conv_versions):
        """Consume and return the total number of repetitions from a scheduler config.

        ``repetitions`` wins; otherwise ``reps_per_version`` (default 1) times
        the number of config versions.
        """
        per_version = self.consume(schedulerC, 'reps_per_version', 1)
        return self.consume(schedulerC, 'repetitions', per_version * num_conv_versions)

    def _calc_num_jobs(self, schedC, num_conv_versions):
        """Return how many jobs are needed to cover all repetitions (`schedC` is not mutated)."""
        schedulerC = copy.deepcopy(schedC)
        reps = self._total_reps(schedulerC, num_conv_versions)
        agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
        reps_per_job = reps_per_agent * agents_per_job
        return math.ceil(reps / reps_per_job)

    def _reps_for_job(self, schedC, task_id, num_conv_versions):
        """Return the repetition ids handled by `task_id` (all of them if task_id is None).

        Repetitions are dealt round-robin over the jobs.
        """
        schedulerC = copy.deepcopy(schedC)
        num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions)
        reps = self._total_reps(schedulerC, num_conv_versions)
        if task_id is None:
            return list(range(0, reps))
        reps_for_job = [[] for _ in range(num_jobs)]
        for i in range(reps):
            reps_for_job[i % num_jobs].append(i)
        return reps_for_job[task_id]

    def run_local(self, filename, name, task_id, sweep_id):
        """Run the repetitions of experiment `name` assigned to `task_id` on this machine."""
        self.task_id = task_id
        config = self.load_config(filename, name)
        num_conv_versions = self._get_num_conv_versions(config)
        schedulerC = copy.deepcopy(config.get('scheduler', {}))
        rep_ids = self._reps_for_job(schedulerC, task_id, num_conv_versions)
        self.sweep_id = sweep_id
        # _init_sweep consumes the 'sweep' section; hand it a copy so that
        # _run_process still sees 'sweep.enable' and starts wandb agents.
        self._init_sweep(copy.deepcopy(config))
        self._fork_processes(config, rep_ids)

    def run_slurm(self, filename, name):
        """Submit experiment `name` from `filename` as a slurm job array.

        Optionally pre-validates the runner, creates the wandb sweep if
        enabled, builds a bash script that calls ``main.py`` once per array
        task, asks for confirmation when config parts were left unconsumed
        (or verification was requested), submits, and appends to job_hist.log.
        """
        assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
        config = self.load_config(filename, name)
        slurmC = self.consume(config, 'slurm', expand=True)
        schedC = self.consume(config, 'scheduler')
        s_name = self.consume(slurmC, 'name')

        num_conv_versions = self._get_num_conv_versions(config)

        # Pre Validation: the wandb section is expanded on a copy only to
        # surface expansion errors before submitting.
        runnerName = self.consume(config, 'runner')
        self.consume(copy.deepcopy(config), 'wandb', {}, expand=True)
        if self.consume(slurmC, 'pre_validate', True):
            Runner = self.runners[runnerName]
            runner = Runner(self, config)
            runner.setup('PreDeployment-Validation')

        self._init_sweep(config)
        self.consume(config, 'wandb')

        python_script = 'main.py'
        sh_lines = ['#!/bin/bash']
        sh_lines += self.consume(slurmC, 'sh_lines', [])
        if venv := self.consume(slurmC, 'venv', False):
            sh_lines += [f'source activate {venv}']
        python_cmd = f'python3 {python_script} {filename} {name} -t $SLURM_ARRAY_TASK_ID --sweep_id {self.sweep_id}'
        if self.consume(slurmC, 'xvfb', False):
            python_cmd = f'xvfb-run {python_cmd}'
        sh_lines += [python_cmd]
        script = "\n".join(sh_lines)

        num_jobs = self._calc_num_jobs(schedC, num_conv_versions)
        last_job_idx = num_jobs - 1
        num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
        array = f'0-{last_job_idx}%{num_parallel_jobs}'
        job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)

        # Keys that are not needed for submission; drop them before the
        # leftover check below.
        for meta_key in ('name', 'project', 'vars', 'grid', 'ablative'):
            self.consume(config, meta_key, '')

        if config != {}:
            print('[!] Unconsumed Config Parts:')
            pprint(config)
        if self.verify or config != {}:
            input(f'[?] Press enter to submit {num_jobs} job(s) to slurm (Ctrl+C to abort) ')
        job_id = job.submit()
        print(f'[>] Job submitted to slurm with id {job_id}')
        with open('job_hist.log', 'a') as f:
            # Array indices run from 0 to last_job_idx inclusive.
            f.write(f'{filename}:{name} submitted to slurm with ids {job_id}_0-{last_job_idx} on [git:{self.get_version()}] at {datetime.datetime.now()}\n')

    def _fork_processes(self, config, rep_ids):
        """Split `rep_ids` over up to ``agents_per_job`` agents and run them.

        A single agent runs in the calling thread; otherwise each agent is a
        ``Parallelization_Primitive`` (Process by default) and all are joined.
        """
        schedC = self.consume(config, 'scheduler', {})
        agents_per_job = self.consume(schedC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
        # The scheduler section is consumed here, but the per-agent code reads
        # scheduler.bind_agent_to_core from the config it receives: keep it.
        if 'bind_agent_to_core' in schedC:
            config['scheduler'] = {'bind_agent_to_core': schedC['bind_agent_to_core']}

        node_reps = len(rep_ids)
        num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))

        if num_p == 1:
            print('[i] Running within main thread')
            self._run_process(config, rep_ids=rep_ids, p_ind=0)
            return

        procs = []
        reps_done = 0
        for p in range(num_p):
            print(f'[i] Spawning seperate thread/process ({p+1}/{num_p})')
            num_reps = min(node_reps - reps_done, reps_per_agent)
            proc_rep_ids = list(rep_ids[reps_done:reps_done + num_reps])
            proc = Parallelization_Primitive(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
            proc.start()
            procs.append(proc)
            reps_done += num_reps

        for proc in procs:
            proc.join()
        print('[i] All threads/processes have terminated')

    def _init_sweep(self, config):
        """Create a wandb sweep from the 'sweep' section if enabled and none exists yet.

        Consumes ``sweep.enable`` (and, when creating, the whole 'sweep'
        section) from `config`; stores the new id in ``self.sweep_id``.
        """
        if self.sweep_id is None and self.consume(config, 'sweep.enable', False):
            sweepC = self.consume(config, 'sweep')
            wandbC = copy.deepcopy(config.get('wandb', {}))
            project = self.consume(wandbC, 'project', config.get('project', config.get('name')))

            self.sweep_id = wandb.sweep(
                sweep=sweepC,
                project=project
            )

    def _run_process(self, orig_config, rep_ids, p_ind):
        """Entry point of one agent: run a wandb sweep agent or plain repetitions."""
        config = copy.deepcopy(orig_config)
        if self.consume(config, 'sweep.enable', False):
            wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
        else:
            self.consume(config, 'sweep', {})
            self._run_single(config, rep_ids, p_ind=p_ind)

    def _bind_to_core_if_requested(self, config, p_ind):
        """Consume the 'scheduler' section and pin this process to one core if requested."""
        schedC = self.consume(config, 'scheduler', {})
        if self.consume(schedC, 'bind_agent_to_core', False):
            os.sched_setaffinity(0, [p_ind % os.cpu_count()])

    def _run_single(self, orig_config, rep_ids, p_ind):
        """Execute the repetitions `rep_ids` one after another.

        Each repetition gets its own config version (see _make_config_for_run),
        its own ``wandb.init`` run (retried up to 5 times on
        ``wandb.errors.CommError``) and a fresh runner instance.
        """
        print(f'[P{p_ind}] I will work on reps {rep_ids}')
        runnerName = self.consume(orig_config, 'runner')
        project = self.consume(orig_config, 'wandb.project', orig_config.get('project', orig_config.get('name')))
        Runner = self.runners[runnerName]

        self._bind_to_core_if_requested(orig_config, p_ind)

        for r in rep_ids:
            self.run_id = r
            config = copy.deepcopy(orig_config)
            runnerConf = self._make_config_for_run(config, r)
            wandbC = self.consume(runnerConf, 'wandb', {}, expand=True, delta_desc=runnerConf.pop('delta_desc', 'BASE'))
            # Keep the informative tail of over-long job types; presumably a
            # wandb length limit — TODO confirm the exact limit.
            if 'job_type' in wandbC and len(wandbC['job_type']) > 62:
                wandbC['job_type'] = "..." + wandbC['job_type'][-50:]

            # Consume once, outside the retry loop, so retries keep them.
            reinit = self.consume(wandbC, 'reinit', DEFAULT_REINIT)
            settings = self.consume(wandbC, 'settings', {})

            retry = 5
            while retry:
                # Runners consume (mutate) their config: each attempt needs a fresh copy.
                attemptConf = copy.deepcopy(runnerConf)
                try:
                    with wandb.init(
                        project=project,
                        config=copy.deepcopy(attemptConf),
                        reinit=reinit,
                        settings=wandb.Settings(**settings),
                        **wandbC
                    ) as run:
                        runner = Runner(self, attemptConf)
                        runner.setup(wandbC['group'] + wandbC['job_type'])
                        runner.run(run)
                except wandb.errors.CommError:
                    retry -= 1
                    if retry:
                        print('Catched CommErr; retrying...')
                        time.sleep(int(60 * random.random()))
                    else:
                        print('Catched CommErr; not retrying')
                        raise
                else:
                    retry = 0

            if attemptConf != {}:
                msg = ('Config was not completely consumed: ', attemptConf)
                if REQUIRE_CONFIG_CONSUMED:
                    raise Exception(msg)
                else:
                    print(msg)

    def _run_from_sweep(self, orig_config, p_ind):
        """Run one wandb sweep trial: apply ``wandb.config`` over the config and run it.

        Called repeatedly by ``wandb.agent`` with the same `orig_config`, so
        it works on a private deep copy.
        """
        orig_config = copy.deepcopy(orig_config)
        runnerName = self.consume(orig_config, 'runner')
        wandbC = self.consume(orig_config, 'wandb', {}, expand=True)
        project = self.consume(wandbC, 'project', orig_config.get('project', orig_config.get('name')))
        Runner = self.runners[runnerName]

        self._bind_to_core_if_requested(orig_config, p_ind)

        with wandb.init(
            project=project,
            reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
            settings=wandb.Settings(**self.consume(wandbC, 'settings', {})),
            **wandbC
        ) as run:
            config = copy.deepcopy(orig_config)
            self.deep_update(config, wandb.config)
            run.config = copy.deepcopy(config)
            runner = Runner(self, config)
            runner.setup(wandbC['group'] + wandbC['job_type'])
            runner.run(run)

        if config != {}:
            msg = ('Config was not completely consumed: ', config)
            if REQUIRE_CONFIG_CONSUMED:
                raise Exception(msg)
            else:
                print(msg)

    def _make_configs_for_runs(self, config):
        """Return all config versions: grid combinations followed by ablative variants."""
        c = copy.deepcopy(config)
        grid_versions = self._make_grid_versions(c)
        return self._make_ablative_versions(c, grid_versions)

    def _get_num_conv_versions(self, config):
        """Return how many distinct config versions `config` expands to."""
        return len(self._make_configs_for_runs(config))

    def _make_config_for_run(self, config, r):
        """Return the config version used by repetition `r` (versions cycle with r)."""
        all_versions = self._make_configs_for_runs(config)
        i = r % len(all_versions)
        print(f'[d] Running version {i}/{len(all_versions)} in run {r}')
        cur_conf = all_versions[i]
        cur_conf.pop('ablative', None)
        return cur_conf

    def _make_grid_versions(self, config):
        """Expand the 'grid' section into its cartesian product (or return [config])."""
        if 'grid' in config:
            return params_combine(config, 'grid', itertools.product)
        return [config]

    def _make_ablative_versions(self, config, grid_versions):
        """Append one-parameter-at-a-time 'ablative' variants to `grid_versions`."""
        if 'ablative' in config:
            return grid_versions + ablative_expand(grid_versions)
        return grid_versions

    def from_args(self):
        """Parse the command line and run locally or submit to slurm."""
        import argparse

        parser = argparse.ArgumentParser()
        parser.add_argument("config_file", nargs='?', default=None)
        parser.add_argument("experiment", nargs='?', default='DEFAULT')
        parser.add_argument("-s", "--slurm", action="store_true")
        parser.add_argument("-w", "--worker", action="store_true")
        parser.add_argument("-t", "--task_id", default=None, type=int)
        parser.add_argument("--sweep_id", default=None, type=str)
        parser.add_argument("--ask_verify", action="store_true")

        args = parser.parse_args()

        print(f'[i] I have task_id {args.task_id}')
        print(f'[i] Running on version [git:{self.get_version()}]')

        if args.worker:
            raise Exception('Not yet implemented')

        assert args.config_file is not None, 'Need to supply config file.'

        if args.slurm:
            if args.ask_verify:
                self.verify = True
            self.run_slurm(args.config_file, args.experiment)
        else:
            self.run_local(args.config_file, args.experiment, args.task_id, args.sweep_id)


def params_combine(config: dict, key: str, iter_func):
    """Expand the list-valued leaves under ``config[key]`` into one config per combination.

    `iter_func` combines the value lists (e.g. ``itertools.product``). Each
    result is a deep copy of `config` without `key`, with the chosen values
    inserted at their (nested) position and ``delta_desc`` extended.
    """
    if iter_func is None:
        return [config]

    # {('a', 'b'): [values...]}: flat view of the (possibly nested) grid.
    tuple_dict = flatten_dict_to_tuple_keys(config[key])
    _param_names = ['.'.join(t) for t in tuple_dict]

    combined_configs = []
    for values in iter_func(*tuple_dict.values()):
        _config = copy.deepcopy(config)
        del _config[key]
        for t, value in zip(tuple_dict.keys(), values):
            insert_deep_dictionary(d=_config, t=t, value=value)
        _config = extend_config_name(_config, _param_names, values)
        combined_configs.append(_config)
    return combined_configs


def ablative_expand(conf_list):
    """For each config, create one variant per single value listed under 'ablative'."""
    combined_configs = []
    for config in conf_list:
        tuple_dict = flatten_dict_to_tuple_keys(config['ablative'])
        for key, values in tuple_dict.items():
            param_name = '.'.join(key)
            for val in values:
                _config = copy.deepcopy(config)
                insert_deep_dictionary(_config, key, val)
                _config = extend_config_name(_config, [param_name], [val])
                combined_configs.append(_config)
    return combined_configs


def flatten_dict_to_tuple_keys(d: MutableMapping):
    """Map each sequence-valued leaf of nested `d` to its key path tuple.

    Non-sequence leaves are dropped.
    """
    flat_dict = {}
    for k, v in d.items():
        if isinstance(v, MutableMapping):
            sub_dict = flatten_dict_to_tuple_keys(v)
            flat_dict.update({(k, *sk): sv for sk, sv in sub_dict.items()})
        elif isinstance(v, MutableSequence):
            flat_dict[(k,)] = v
    return flat_dict


def insert_deep_dictionary(d: MutableMapping, t: tuple, value):
    """Set ``d[t[0]][t[1]]...[t[-1]] = value``, creating missing levels; non-tuple `t` is a plain key."""
    if type(t) is tuple:
        if len(t) == 1:
            d[t[0]] = value
        else:
            if t[0] not in d:
                d[t[0]] = dict()
            insert_deep_dictionary(d[t[0]], t[1:], value)
    else:
        d[t] = value


def append_deep_dictionary(d: MutableMapping, t: tuple, value):
    """Append `value` to the list at key path `t` in `d`, creating levels/lists as needed."""
    if type(t) is tuple:
        if len(t) == 1:
            if t[0] not in d:
                d[t[0]] = []
            d[t[0]].append(value)
        else:
            if t[0] not in d:
                d[t[0]] = dict()
            append_deep_dictionary(d[t[0]], t[1:], value)
    else:
        d[t] = value


def extend_config_name(config: dict, param_names: list, values: list) -> dict:
    """Append a short parameter/value description to ``config['delta_desc']`` (in place)."""
    _converted_name = convert_param_names(param_names, values)
    if 'delta_desc' in config:
        config['delta_desc'] = config['delta_desc'] + '_' + _converted_name
    else:
        config['delta_desc'] = _converted_name
    return config


def convert_param_names(_param_names: list, values: list) -> str:
    """Build a compact, filesystem-friendly ``<short name><value>`` description."""
    _converted_name = '_'.join("{}{}".format(shorten_param(k), v) for k, v in zip(_param_names, values))
    _converted_name = re.sub(r"[' ]", '', _converted_name)
    _converted_name = re.sub(r'["]', '', _converted_name)
    _converted_name = re.sub(r"[(\[]", '_', _converted_name)
    _converted_name = re.sub(r"[)\]]", '', _converted_name)
    _converted_name = re.sub(r"[,]", '_', _converted_name)
    return _converted_name


def shorten_param(_param_name):
    """Abbreviate a dotted parameter name.

    Path parts are cut to 3 characters; the leaf becomes the initials of
    its '_'-separated words.
    """
    name_parts = _param_name.split('.')
    shortened_parts = '.'.join(s[:3] for s in name_parts[:-1])
    # s[:1] instead of s[0]: empty words (leading/double underscores) must not crash.
    shortened_leaf = ''.join(s[:1] for s in name_parts[-1].split('_'))
    if shortened_parts:
        return shortened_parts + '.' + shortened_leaf
    return shortened_leaf


class Slate_Runner():
    """Base class for runners; Slate instantiates it with ``(slate, config)``."""

    def __init__(self, slate, config):
        self.slate = slate
        self.config = config

    def setup(self, name=None):
        """Hook called before run().

        Slate passes 'PreDeployment-Validation' during slurm pre-validation,
        otherwise wandb group + job_type.
        """
        pass

    def run(self, run):
        """Execute the experiment; `run` is the object yielded by ``wandb.init``."""
        pass


class Print_Config_Runner(Slate_Runner):
    """Print the raw and the placeholder-expanded config, then mark it consumed."""

    def run(self, run):
        slate, config = self.slate, self.config

        pprint(config)
        print('---')
        pprint(slate.consume(config, '', expand=True))
        config.clear()


class Void_Runner(Slate_Runner):
    """Do nothing except mark the whole config as consumed."""

    def run(self, run):
        self.config.clear()


class PDB_Runner(Slate_Runner):
    """Drop into the pdb debugger."""

    def run(self, run):
        d()


if __name__ == '__main__':
    raise Exception('You are using it wrong...')