import wandb
import yaml
import os
import math
import random
import copy
import collections.abc
from functools import partial
from multiprocessing import Process
from threading import Thread
import git
import datetime

import pdb
d = pdb.set_trace  # debugging shortcut: call d() to drop into the debugger

# If True, a runner that leaves config keys unconsumed raises; otherwise the
# leftover keys are only printed.
REQUIRE_CONFIG_CONSUMED = False

WANDB_START_METHOD = 'fork'
REINIT = True
# Primitive used to run several agents of one job concurrently.
Parallelization_Primitive = Thread  # Process

try:
    import pyslurm
except ImportError:
    slurm_avaible = False
    print('[!] Slurm not avaible.')
else:
    slurm_avaible = True

# TODO: Implement Slurm
# TODO: Implement Parallel
# TODO: Implement Testing
# TODO: Implement Ablative
# TODO: Implement PCA


class Slate():
    """Experiment launcher: loads layered YAML configs and dispatches runs.

    Runs are executed locally (optionally split over several threads or
    processes) or submitted to Slurm as a job array. Each repetition is wrapped
    in a wandb run and handed to the runner named by the config's ``runner``
    key. Runners are expected to ``consume`` the config keys they use; leftover
    keys are reported afterwards.
    """

    def __init__(self, runners):
        """
        Args:
            runners: dict mapping runner names to callables
                ``runner(slate, run, config)``. A ``'printConfig'`` entry is
                added to this dict in place.
        """
        self.runners = runners
        self.runners['printConfig'] = print_config_runner

    def load_config(self, filename, name):
        """Load experiment `name` from YAML file `filename`, resolving imports.

        A deep copy of the merged config is kept in ``self._config`` so that
        strings can reference it as ``{config[...]}`` during expansion. The
        top-level ``vars`` section is removed from the returned config.
        """
        config, stack = self._load_config(filename, name)
        print('[i] Merged Configs: ', stack)
        self._config = copy.deepcopy(config)
        self.consume(config, 'vars', {})
        return config

    def _load_config(self, filename, name, stack=None):
        """Find the YAML document named `name` in `filename` and merge its imports.

        ``import`` is a comma-separated string. Each entry is ``path:name``,
        ``path`` (meaning ``path:DEFAULT``), ``:name`` (a document in the same
        file) or ``$`` (the same file's ``DEFAULT``). Paths are resolved
        relative to the directory of `filename`. Imported configs form the base
        and the importing document is merged on top of them.

        Returns:
            ``(config, stack)``, where `stack` lists every ``file:name`` loaded.

        Raises:
            Exception: if an import entry is malformed or the experiment is not
                found in `filename`.
        """
        # A `[]` default would be shared between calls and keep growing, so a
        # fresh list is created per top-level load.
        if stack is None:
            stack = []
        stack.append(f'{filename}:{name}')
        with open(filename, 'r') as f:
            for doc in yaml.safe_load_all(f):
                # Empty YAML documents load as None; only mappings can match.
                if not isinstance(doc, dict) or 'name' not in doc or doc['name'] != name:
                    continue
                if 'import' in doc:
                    imports = doc['import'].split(',')
                    del doc['import']
                    for imp in imports:
                        imp = imp.strip()
                        if not imp:
                            # Tolerate stray/trailing commas.
                            continue
                        if imp == '$':
                            imp = ':DEFAULT'
                        rel_path, *opt = imp.split(':')
                        if len(opt) == 0:
                            nested_name = 'DEFAULT'
                        elif len(opt) == 1:
                            nested_name = opt[0]
                        else:
                            raise Exception(
                                'Malformed import statement. Must be <path:name>, '
                                '<path> for <path>:DEFAULT, <:name> for a document '
                                'in the same file, or $ for :DEFAULT.')
                        if rel_path:
                            nested_path = os.path.normpath(
                                os.path.join(os.path.dirname(filename), rel_path))
                        else:
                            nested_path = filename
                        child, stack = self._load_config(nested_path, nested_name, stack=stack)
                        doc = self.deep_update(child, doc)
                return doc, stack
        raise Exception(f'Unable to find experiment <{name}> in <{filename}>')

    def deep_update(self, d, u):
        """Recursively merge mapping `u` into dict `d` in place and return `d`.

        Keys of `u` may be dotted paths (``'a.b.c'``), which address nested
        dicts and create them where missing. Mapping values are merged
        recursively into the existing node; any other value overwrites it.
        """
        for kstr, v in u.items():
            ks = kstr.split('.')
            head = d
            # Walk/create the parent node of the final key. A non-dict value in
            # the way is replaced, since it is being overridden by a subtree.
            for k in ks[:-1]:
                if not isinstance(head.get(k), collections.abc.Mapping):
                    head[k] = {}
                head = head[k]
            leaf = ks[-1]
            if isinstance(v, collections.abc.Mapping):
                # Merge into the node at this path, not into a same-named key of
                # the root dict.
                existing = head.get(leaf)
                if not isinstance(existing, collections.abc.Mapping):
                    existing = {}
                head[leaf] = self.deep_update(existing, v)
            else:
                head[leaf] = v
        return d

    def expand_vars(self, string, **kwargs):
        """Expand ``str.format`` placeholders in `string`; non-strings pass through.

        ``{rand}`` becomes a random integer in ``[0, 99999999)``. A value that is
        exactly ``'{rand}'`` is returned as that integer rather than a string.
        """
        if isinstance(string, str):
            rand = int(random.random()*99999999)
            if string == '{rand}':
                return rand
            return string.format(**kwargs, rand=rand)
        return string

    def apply_nested(self, d, f):
        """Replace every leaf value of dict `d` with ``f(leaf)``, in place.

        Nested dicts and lists (at any depth) are traversed; everything else is
        treated as a leaf.
        """
        for k, v in d.items():
            if isinstance(v, dict):
                self.apply_nested(v, f)
            elif isinstance(v, list):
                for i, e in enumerate(v):
                    # Wrap the element so dicts, lists and scalars share one path.
                    ptr = {'PTR': d[k][i]}
                    self.apply_nested(ptr, f)
                    d[k][i] = ptr['PTR']
            else:
                d[k] = f(v)

    def deep_expand_vars(self, dict, **kwargs):
        """Run `expand_vars` once on every leaf of `dict`, in place."""
        self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))

    def consume(self, conf, key, default=None, expand=False, **kwargs):
        """Remove `key` from `conf` and return its value.

        Args:
            conf: config dict, modified in place.
            key: key to take; a dotted key (``'a.b'``) descends into nested dicts.
            default: value returned when the key is missing. ``None`` means the
                key is required (a missing key raises ``KeyError``).
            expand: if True, expand placeholders once in every leaf of the
                value. Otherwise a string value is expanded repeatedly until it
                no longer contains ``{``.
            **kwargs: extra names available to the placeholders, besides
                ``config`` (the merged config) and ``rand``.
        """
        keys_arr = key.split('.')
        if len(keys_arr) > 1:
            child = conf.get(keys_arr[0], {})
            child_keys = '.'.join(keys_arr[1:])
            return self.consume(child, child_keys, default=default, expand=expand, **kwargs)

        k = keys_arr[0]
        if default is not None:
            val = conf.get(k, default)
        else:
            val = conf[k]
        if k in conf:
            del conf[k]

        if expand:
            # Wrapping lets non-dict values (strings, lists) be expanded too;
            # dict values are still expanded in place.
            ptr = {'PTR': val}
            self.deep_expand_vars(ptr, config=self._config, **kwargs)
            val = ptr['PTR']
        elif isinstance(val, str):
            # Expansion may produce an int (e.g. '{rand}'), which ends the loop.
            while isinstance(val, str) and '{' in val:
                val = self.expand_vars(val, config=self._config, **kwargs)
        return val

    def _calc_num_jobs(self, schedC):
        """Return ceil(repetitions / (reps_per_agent * agents_per_job)); each defaults to 1."""
        schedulerC = copy.deepcopy(schedC)
        reps = self.consume(schedulerC, 'repetitions', 1)
        agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
        reps_per_job = reps_per_agent * agents_per_job
        jobs_needed = math.ceil(reps / reps_per_job)
        return jobs_needed

    def _reps_for_job(self, schedC, job_id):
        """Return the repetition indices assigned to `job_id`.

        Repetitions are dealt round-robin over the jobs. ``job_id=None`` returns
        every repetition. The bucket index is ``job_id - 1``. With Slurm's
        0-based array IDs, job 0 therefore maps to the last bucket through
        negative indexing, and each bucket is still used exactly once.
        """
        schedulerC = copy.deepcopy(schedC)
        num_jobs = self._calc_num_jobs(schedulerC)
        reps = self.consume(schedulerC, 'repetitions', 1)
        if job_id is None:
            return list(range(reps))
        # One independent list per job; `[[]] * num_jobs` would alias one list.
        reps_for_job = [[] for _ in range(num_jobs)]
        for i in range(reps):
            reps_for_job[i % num_jobs].append(i)
        return reps_for_job[job_id-1]

    def run_local(self, filename, name, sha, job_id):
        """Run this machine's share of repetitions (all of them if `job_id` is None)."""
        config = self.load_config(filename, name)
        schedulerC = copy.deepcopy(config.get('scheduler', {}))
        rep_ids = self._reps_for_job(schedulerC, job_id)
        self._fork_processes(config, rep_ids)

    def run_slurm(self, filename, name, sha):
        """Submit the experiment to Slurm as a job array via pyslurm.

        Each array task re-invokes ``main.py`` with ``-j $SLURM_ARRAY_TASK_ID``.
        Slurm options left in the ``slurm`` section after the keys used here
        are passed through to ``pyslurm.JobSubmitDescription``. A line
        recording the submission is appended to ``job_hist.log``.
        """
        assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
        config = self.load_config(filename, name)
        slurmC = self.consume(config, 'slurm', expand=True)
        schedC = self.consume(config, 'scheduler')
        s_name = self.consume(slurmC, 'name')

        python_script = 'main.py'
        sh_lines = ['#!/bin/bash']
        sh_lines += self.consume(slurmC, 'sh_lines', [])
        if venv := self.consume(slurmC, 'venv', False):
            sh_lines += [f'source activate {venv}']
        sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
        script = "\n".join(sh_lines)

        num_jobs = self._calc_num_jobs(schedC)
        last_job_idx = num_jobs - 1
        num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
        array = f'0-{last_job_idx}%{num_parallel_jobs}'
        job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
        job_id = job.submit()
        print(f'[i] Job submitted to slurm with id {job_id}')
        # Append (not overwrite) so the file keeps a history of submissions; the
        # logged range matches the submitted array `0-{last_job_idx}`.
        with open('job_hist.log', 'a') as f:
            f.write(f'{filename}:{name} submitted to slurm with ids {job_id}_0-{last_job_idx} '
                    f'on [git:{sha}] at {datetime.datetime.now()}\n')

    def _fork_processes(self, config, rep_ids):
        """Split `rep_ids` over up to ``agents_per_job`` agents and run them.

        Each agent gets at most ``reps_per_agent`` repetitions. A single agent
        runs in the calling thread; otherwise one `Parallelization_Primitive`
        is started per agent and all are joined before returning.
        """
        # `scheduler` is optional here, as in run_local.
        schedC = self.consume(config, 'scheduler', {})
        agents_per_job = self.consume(schedC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)

        node_reps = len(rep_ids)
        num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))

        if num_p == 1:
            print('[i] Running within main thread')
            # Go through _run_process so sweep handling also works for one agent.
            self._run_process(config, rep_ids=rep_ids, p_ind=0)
            return

        procs = []

        reps_done = 0

        for p in range(num_p):
            print(f'[i] Spawning seperate thread/process ({p+1}/{num_p})')
            num_reps = min(node_reps - reps_done, reps_per_agent)
            proc_rep_ids = rep_ids[reps_done:reps_done + num_reps]
            proc = Parallelization_Primitive(
                target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
            proc.start()
            procs.append(proc)
            reps_done += num_reps

        for proc in procs:
            proc.join()
        print('[i] All threads/processes have terminated')

    def _run_process(self, orig_config, rep_ids, p_ind):
        """Run one agent's repetitions on a private copy of the config.

        If ``sweep.enable`` is set, a wandb sweep is created from the ``sweep``
        section and a wandb agent runs ``len(rep_ids)`` trials. Otherwise the
        repetitions run sequentially via `_run_single`.
        """
        config = copy.deepcopy(orig_config)
        if self.consume(config, 'sweep.enable', False):
            sweepC = self.consume(config, 'sweep')
            project = self.consume(copy.deepcopy(config['wandb']), 'project')
            sweep_id = wandb.sweep(
                sweep=sweepC,
                project=project,
                settings=wandb.Settings(start_method=WANDB_START_METHOD)
            )
            wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
        else:
            self.consume(config, 'sweep', {})
            self._run_single(config, rep_ids, p_ind=p_ind)

    def _run_single(self, orig_config, rep_ids, p_ind):
        """Run each repetition in `rep_ids` in its own wandb run.

        ``runner`` and ``wandb`` are consumed from `orig_config` (in place).
        Every repetition gets a fresh deep copy of the remaining config.
        """
        print(f'[P{p_ind}] I will work on reps {rep_ids}')
        runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
        project = self.consume(wandbC, 'project')

        runner = self.runners[runnerName]

        for r in rep_ids:
            # Always copy the untouched base config; reusing the consumed config
            # would leave later repetitions with missing keys.
            config = copy.deepcopy(orig_config)
            with wandb.init(
                project=project,
                config=copy.deepcopy(config),
                reinit=REINIT,
                settings=wandb.Settings(start_method=WANDB_START_METHOD),
                **wandbC
            ) as run:
                runner(self, run, config)
                self._check_consumed(config)

    def _run_from_sweep(self, orig_config, p_ind):
        """Execute a single sweep trial (called by ``wandb.agent``).

        The sweep's parameters from ``wandb.config`` are merged (dotted keys
        allowed) over a copy of the base config before the runner is called.
        """
        # wandb.agent calls this repeatedly with the same config object, so
        # consume from a copy to keep it intact for later trials.
        base_config = copy.deepcopy(orig_config)
        runnerName, wandbC = self.consume(base_config, 'runner'), self.consume(base_config, 'wandb', {}, expand=True)
        project = self.consume(wandbC, 'project')

        runner = self.runners[runnerName]

        with wandb.init(
            project=project,
            reinit=REINIT,
            settings=wandb.Settings(start_method=WANDB_START_METHOD),
            **wandbC
        ) as run:
            config = copy.deepcopy(base_config)
            self.deep_update(config, wandb.config)
            runner(self, run, config)
            self._check_consumed(config)

    def _check_consumed(self, config):
        """Report keys the runner left in `config`; raise if REQUIRE_CONFIG_CONSUMED."""
        if config != {}:
            msg = ('Config was not completely consumed: ', config)
            if REQUIRE_CONFIG_CONSUMED:
                raise Exception(msg)
            else:
                print(msg)

    def from_args(self):
        """Parse command-line arguments and run locally or submit to Slurm.

        Positional: ``config_file`` and ``experiment`` (default ``DEFAULT``).
        Flags: ``-s/--slurm`` submits to Slurm, ``-j/--job_id`` selects this
        job's share of repetitions, ``-w/--worker`` is not implemented.
        """
        import argparse

        parser = argparse.ArgumentParser()
        parser.add_argument("config_file", nargs='?', default=None)
        parser.add_argument("experiment", nargs='?', default='DEFAULT')
        parser.add_argument("-s", "--slurm", action="store_true")
        parser.add_argument("-w", "--worker", action="store_true")
        parser.add_argument("-j", "--job_id", default=None, type=int)

        args = parser.parse_args()

        repo = git.Repo(search_parent_directories=True)
        sha = repo.head.object.hexsha

        print(f'[i] I have job_id {args.job_id}')
        print(f'[i] Running on version [git:{sha}]')

        if args.worker:
            raise Exception('Not yet implemented')

        assert args.config_file is not None, 'Need to supply config file.'

        if args.slurm:
            self.run_slurm(args.config_file, args.experiment, sha)
        else:
            self.run_local(args.config_file, args.experiment, sha, args.job_id)


def print_config_runner(slate, run, config):
    """Runner that pretty-prints the variable-expanded config, then empties it.

    Args:
        slate: the `Slate` instance, used for ``consume`` and expansion.
        run: the active wandb run (unused).
        config: the run's config dict; cleared in place so nothing is reported
            as unconsumed.
    """
    from pprint import pprint
    ptr = {'ptr': config}
    # `expand=True` is a consume() option; pprint() has no such keyword and
    # would raise TypeError.
    pprint(slate.consume(ptr, 'ptr', expand=True))
    config.clear()


if __name__ == '__main__':
    raise Exception('You are using it wrong...')