import wandb
import yaml
import os
import math
import random
import copy
import collections.abc
from functools import partial
from multiprocessing import Process
from threading import Thread
import git
import datetime
from pprint import pprint

import pdb
d = pdb.set_trace

# If True, a run that leaves keys of its config unconsumed raises an exception
# instead of only printing the leftovers.
REQUIRE_CONFIG_CONSUMED = False

DEFAULT_START_METHOD = 'fork'
DEFAULT_REINIT = True

# Primitive used by Slate._fork_processes to run several agents concurrently.
Parallelization_Primitive = Thread  # Process

try:
    import pyslurm
except ImportError:
    slurm_avaible = False
    print('[!] Slurm not avaible.')
else:
    slurm_avaible = True

# TODO: Implement Testing
# TODO: Implement Ablative


class Slate():
    """Load layered YAML experiment configs and run them locally or via Slurm.

    Config values are *consumed* (removed from their dict) when read through
    `consume`, so that keys nobody used can be reported after a run.
    """

    def __init__(self, runners):
        """Create a Slate.

        Args:
            runners: mapping of runner name -> Slate_Runner subclass. Entries
                are added to (and may override) the built-in 'printConfig'
                and 'pdb' runners.
        """
        self.runners = {
            'printConfig': Print_Config_Runner,
            'pdb': PDB_Runner,
        }
        self.runners.update(runners)
        self._version = False
        self.sweep_id = None
        # Snapshot used for '{config...}' placeholders; set by load_config.
        self._config = {}

    def load_config(self, filename, name):
        """Load experiment `name` from `filename` with all its imports merged.

        A deep copy of the merged config is kept in `self._config` for
        placeholder expansion; the returned config has 'vars' consumed.
        """
        config, stack = self._load_config(filename, name)
        print('[i] Merged Configs: ', stack)
        self._config = copy.deepcopy(config)
        self.consume(config, 'vars', {})
        return config

    def _load_config(self, filename, name, stack=None):
        """Recursively load experiment `name` from the YAML file `filename`.

        The document whose 'name' equals `name` is selected. Its optional
        comma-separated 'import' entries are loaded (recursively) and the
        document is deep-merged on top of them; because each import is merged
        under the already-merged result, earlier imports take precedence over
        later ones.

        Args:
            filename: path of a (possibly multi-document) YAML file.
            name: experiment name to look up.
            stack: list collecting 'file:name' of every merged config; a new
                list is created when omitted.

        Returns:
            (merged config dict, stack)

        Raises:
            Exception: on a malformed import entry or if `name` is not found.
        """
        # A literal [] default would be shared between calls and keep growing.
        if stack is None:
            stack = []
        stack.append(f'{filename}:{name}')
        with open(filename, 'r') as f:
            for doc in yaml.safe_load_all(f):
                # Empty YAML documents load as None; only mappings are configs.
                if not isinstance(doc, collections.abc.Mapping):
                    continue
                if 'name' not in doc or doc['name'] != name:
                    continue
                if 'import' in doc:
                    imports = doc['import'].split(',')
                    del doc['import']
                    for imp in imports:
                        imp = imp.strip()
                        if not imp:
                            # Tolerate trailing / doubled commas.
                            continue
                        if imp == '$':
                            imp = ':DEFAULT'
                        rel_path, *opt = imp.split(':')
                        if len(opt) == 0:
                            nested_name = 'DEFAULT'
                        elif len(opt) == 1:
                            nested_name = opt[0]
                        else:
                            raise Exception('Malformed import statement. Must be <file>, <file>:<experiment>, :<experiment> or $ (short for :DEFAULT); <file> alone means <file>:DEFAULT.')
                        # Paths are relative to the importing file; an empty
                        # path refers to the importing file itself.
                        if rel_path:
                            nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path))
                        else:
                            nested_path = filename
                        child, stack = self._load_config(nested_path, nested_name, stack=stack)
                        doc = self.deep_update(child, doc)
                return doc, stack
        raise Exception(f'Unable to find experiment <{name}> in <{filename}>')

    def deep_update(self, d, u, traverse_dot_notation=True):
        """Recursively merge mapping `u` into `d` in place and return `d`.

        With `traverse_dot_notation`, a string key 'a.b.c' of `u` addresses
        d['a']['b']['c'] (intermediate dicts are created as needed). Inside a
        'parameters' entry, keys are taken literally, without dot traversal.
        """
        for kstr, v in u.items():
            if traverse_dot_notation and isinstance(kstr, str):
                ks = kstr.split('.')
            else:
                ks = [kstr]
            # Decided per key so that a 'parameters' key does not switch off
            # dot traversal for the siblings processed after it.
            traverse_child = traverse_dot_notation
            head = d
            for k in ks:
                if k in ['parameters']:
                    traverse_child = False
                last_head = head
                if k not in head:
                    head[k] = {}
                head = head[k]
            if isinstance(v, collections.abc.Mapping):
                # Merge into the entry the (possibly dotted) key addresses;
                # a non-mapping value there (e.g. None) is replaced.
                if not isinstance(head, collections.abc.Mapping):
                    head = {}
                last_head[ks[-1]] = self.deep_update(head, v, traverse_dot_notation=traverse_child)
            else:
                last_head[ks[-1]] = v
        return d

    def expand_vars(self, string, **kwargs):
        """Expand `str.format` placeholders in `string` using `kwargs`.

        A string that is exactly '{rand}' becomes a random int; '{rand}'
        inside a longer string is formatted as that int. Non-strings are
        returned unchanged.
        """
        if isinstance(string, str):
            rand = int(random.random() * 99999999)
            if string == '{rand}':
                return rand
            return string.format(**kwargs, rand=rand)
        return string

    def apply_nested(self, d, f):
        """Replace every leaf value x of dict `d` by f(x), in place.

        Nested dicts and lists (including lists of dicts/lists) are walked
        recursively.
        """
        for k, v in d.items():
            if isinstance(v, dict):
                self.apply_nested(v, f)
            elif isinstance(v, list):
                for i, e in enumerate(v):
                    # Wrap the element so scalars can be replaced in the list.
                    ptr = {'PTR': e}
                    self.apply_nested(ptr, f)
                    v[i] = ptr['PTR']
            else:
                d[k] = f(v)

    def deep_expand_vars(self, dict, **kwargs):
        """Apply `expand_vars` (with `kwargs`) to every leaf of `dict`, in place."""
        self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))

    def consume(self, conf, key, default=None, expand=False, **kwargs):
        """Remove `key` from `conf` and return its value.

        Args:
            conf: config mapping to read from.
            key: key to consume; a dotted key 'a.b' addresses conf['a']['b']
                and only the leaf entry is removed.
            default: returned when the key is missing (or `conf` is not a
                mapping). If None, the key is required and a missing key
                raises KeyError.
            expand: if True, every string inside the value is expanded once
                via `expand_vars`. If False, a string value is expanded
                repeatedly until it contains no '{'.
            **kwargs: extra names for placeholder expansion; `config` is
                always the snapshot taken by `load_config`.
        """
        keys_arr = key.split('.')
        if len(keys_arr) > 1:
            child = conf.get(keys_arr[0], {})
            child_keys = '.'.join(keys_arr[1:])
            return self.consume(child, child_keys, default=default, expand=expand, **kwargs)

        k = keys_arr[0]
        if default is not None:
            if not isinstance(conf, collections.abc.Mapping):
                return default
            val = conf.get(k, default)
        else:
            val = conf[k]
        if k in conf:
            del conf[k]

        if expand:
            # Wrap so that top-level strings and lists are expanded as well;
            # dicts are still expanded in place.
            ptr = {'PTR': val}
            self.deep_expand_vars(ptr, config=self._config, **kwargs)
            val = ptr['PTR']
        elif isinstance(val, str):
            # Placeholders may resolve to further placeholders. Stop when the
            # result is no longer a string ('{rand}' yields an int) or when
            # expansion makes no progress (self-referencing value).
            while isinstance(val, str) and '{' in val:
                expanded = self.expand_vars(val, config=self._config, **kwargs)
                if expanded == val:
                    break
                val = expanded
        return val

    def get_version(self):
        """Return (and cache) the hexsha of the current git HEAD."""
        if not self._version:
            repo = git.Repo(search_parent_directories=True)
            sha = repo.head.object.hexsha
            self._version = sha
        return self._version

    def _calc_num_jobs(self, schedC):
        """Number of jobs needed: ceil(repetitions / (reps_per_agent * agents_per_job)).

        `schedC` is not modified.
        """
        schedulerC = copy.deepcopy(schedC)
        reps = self.consume(schedulerC, 'repetitions', 1)
        agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
        reps_per_job = reps_per_agent * agents_per_job
        jobs_needed = math.ceil(reps / reps_per_job)
        return jobs_needed

    def _reps_for_job(self, schedC, job_id):
        """Return the repetition indices assigned to `job_id`.

        Repetitions are dealt round-robin over the jobs. With `job_id` None,
        all repetitions are returned. Slurm array task ids start at 0 (see the
        array spec in run_slurm), so index `job_id - 1` maps task 0 to the
        last bucket; every task still gets a distinct bucket.
        """
        schedulerC = copy.deepcopy(schedC)
        num_jobs = self._calc_num_jobs(schedulerC)
        reps = self.consume(schedulerC, 'repetitions', 1)
        if job_id is None:
            return list(range(0, reps))
        # One independent list per job ([[]] * n would alias a single list).
        reps_for_job = [[] for _ in range(num_jobs)]
        for i in range(reps):
            reps_for_job[i % num_jobs].append(i)
        return reps_for_job[job_id - 1]

    def run_local(self, filename, name, job_id, sweep_id):
        """Run the repetitions of `job_id` (all if None) in this process."""
        config = self.load_config(filename, name)
        schedulerC = copy.deepcopy(config.get('scheduler', {}))
        rep_ids = self._reps_for_job(schedulerC, job_id)
        self.sweep_id = sweep_id
        self._init_sweep(config)
        self._fork_processes(config, rep_ids)

    def run_slurm(self, filename, name):
        """Submit experiment `name` of `filename` as a Slurm job array.

        Every array task re-invokes main.py locally with its task id as job id
        (and the sweep id, if a sweep was created). Unconsumed keys of the
        'slurm' section are passed on to pyslurm.JobSubmitDescription.
        """
        assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
        config = self.load_config(filename, name)
        slurmC = self.consume(config, 'slurm', expand=True)
        schedC = self.consume(config, 'scheduler', {})
        s_name = self.consume(slurmC, 'name')

        # Pre Validation
        runnerName = self.consume(config, 'runner')
        # Expand a throwaway copy so placeholder errors in 'wandb' surface
        # before anything is submitted.
        self.consume(copy.deepcopy(config), 'wandb', {}, expand=True)
        if self.consume(slurmC, 'pre_validate', True):
            Runner = self.runners[runnerName]
            # Validate on a copy: the runner consumes keys that are still
            # needed below (e.g. 'sweep', 'wandb').
            runner = Runner(self, copy.deepcopy(config))
            runner.setup()

        self._init_sweep(config)
        self.consume(config, 'wandb', {})

        python_script = 'main.py'
        sh_lines = ['#!/bin/bash']
        sh_lines += self.consume(slurmC, 'sh_lines', [])
        if venv := self.consume(slurmC, 'venv', False):
            sh_lines += [f'source activate {venv}']
        # '-s' is the short form of '--slurm' in from_args; the sweep id must
        # go through '--sweep_id', otherwise each task would re-submit itself.
        command = f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID'
        if self.sweep_id is not None:
            command += f' --sweep_id {self.sweep_id}'
        sh_lines += [command]
        script = "\n".join(sh_lines)

        num_jobs = self._calc_num_jobs(schedC)
        last_job_idx = num_jobs - 1
        num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
        array = f'0-{last_job_idx}%{num_parallel_jobs}'

        job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
        job_id = job.submit()
        print(f'[>] Job submitted to slurm with id {job_id}')
        with open('job_hist.log', 'a') as f:
            f.write(f'{filename}:{name} submitted to slurm with ids {job_id}_0-{last_job_idx} on [git:{self.get_version()}] at {datetime.datetime.now()}\n')

    def _fork_processes(self, config, rep_ids):
        """Run `rep_ids`, split over up to 'agents_per_job' concurrent agents.

        A single agent runs in the calling thread; otherwise one
        Parallelization_Primitive per agent is started and joined.
        """
        schedC = self.consume(config, 'scheduler', {})
        agents_per_job = self.consume(schedC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)

        node_reps = len(rep_ids)
        num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))

        if num_p == 1:
            print('[i] Running within main thread')
            self._run_process(config, rep_ids=rep_ids, p_ind=0)
            return

        # Normally every agent takes `reps_per_agent` repetitions. If more were
        # assigned than the agents can cover that way (e.g. a local run without
        # job id receives all repetitions), enlarge the share so that no
        # repetition is silently dropped.
        reps_per_proc = max(reps_per_agent, math.ceil(node_reps / max(num_p, 1)))

        procs = []
        reps_done = 0
        for p in range(num_p):
            num_reps = min(node_reps - reps_done, reps_per_proc)
            if num_reps <= 0:
                break
            print(f'[i] Spawning seperate thread/process ({p+1}/{num_p})')
            proc_rep_ids = rep_ids[reps_done:reps_done + num_reps]
            proc = Parallelization_Primitive(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
            proc.start()
            procs.append(proc)
            reps_done += num_reps

        for proc in procs:
            proc.join()

        print('[i] All threads/processes have terminated')

    def _init_sweep(self, config):
        """Create a wandb sweep if 'sweep.enable' is set and no sweep id exists.

        `config` is left untouched so that `_run_process` still sees
        'sweep.enable' afterwards.
        """
        if self.sweep_id is not None:
            return
        sweepC = copy.deepcopy(config.get('sweep', {}))
        if self.consume(sweepC, 'enable', False):
            wandbC = copy.deepcopy(config['wandb'])
            project = self.consume(wandbC, 'project')
            self.sweep_id = wandb.sweep(
                sweep=sweepC,
                project=project
            )

    def _run_process(self, orig_config, rep_ids, p_ind):
        """Run `rep_ids` as a wandb sweep agent or as plain runs (agent `p_ind`)."""
        config = copy.deepcopy(orig_config)
        sweep_enabled = self.consume(config, 'sweep.enable', False)
        # The sweep section is only needed by wandb.sweep, never by runners.
        self.consume(config, 'sweep', {})
        if sweep_enabled:
            wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
        else:
            self._run_single(config, rep_ids, p_ind=p_ind)

    def _check_consumed(self, config):
        """Report leftover config keys; raise instead if REQUIRE_CONFIG_CONSUMED."""
        if config != {}:
            msg = ('Config was not completely consumed: ', config)
            if REQUIRE_CONFIG_CONSUMED:
                raise Exception(msg)
            else:
                print(msg)

    def _run_single(self, orig_config, rep_ids, p_ind):
        """Run one wandb run per repetition in `rep_ids`, each on a fresh config copy."""
        print(f'[P{p_ind}] I will work on reps {rep_ids}')
        runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
        project = self.consume(wandbC, 'project')
        # Read once: consuming these inside the loop would fall back to the
        # defaults from the second repetition on.
        reinit = self.consume(wandbC, 'reinit', DEFAULT_REINIT)
        start_method = self.consume(wandbC, 'start_method', DEFAULT_START_METHOD)

        Runner = self.runners[runnerName]

        for r in rep_ids:
            # Every repetition starts from the same, unconsumed config.
            config = copy.deepcopy(orig_config)
            with wandb.init(
                project=project,
                config=copy.deepcopy(config),
                reinit=reinit,
                settings=wandb.Settings(start_method=start_method),
                **wandbC
            ) as run:
                runner = Runner(self, config)
                runner.setup()
                runner.run(run)
                self._check_consumed(config)

    def _run_from_sweep(self, orig_config, p_ind):
        """Run one sweep trial: merge wandb's sampled parameters into the config.

        Called by wandb.agent once per trial with the same `orig_config`, so
        that config is only copied, never consumed.
        """
        base_config = copy.deepcopy(orig_config)
        runnerName = self.consume(base_config, 'runner')
        wandbC = self.consume(base_config, 'wandb', {}, expand=True)
        project = self.consume(wandbC, 'project')

        Runner = self.runners[runnerName]

        with wandb.init(
            project=project,
            reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
            settings=wandb.Settings(start_method=self.consume(wandbC, 'start_method', DEFAULT_START_METHOD)),
            **wandbC
        ) as run:
            config = base_config
            self.deep_update(config, wandb.config)
            run.config.update(copy.deepcopy(config), allow_val_change=True)

            runner = Runner(self, config)
            runner.setup()
            runner.run(run)
            self._check_consumed(config)

    def from_args(self):
        """Parse the command line and dispatch to run_slurm or run_local."""
        import argparse

        parser = argparse.ArgumentParser()
        parser.add_argument("config_file", nargs='?', default=None)
        parser.add_argument("experiment", nargs='?', default='DEFAULT')
        parser.add_argument("-s", "--slurm", action="store_true")
        parser.add_argument("-w", "--worker", action="store_true")
        parser.add_argument("-j", "--job_id", default=None, type=int)
        parser.add_argument("--sweep_id", default=None, type=str)

        args = parser.parse_args()

        print(f'[i] I have job_id {args.job_id}')
        print(f'[i] Running on version [git:{self.get_version()}]')

        if args.worker:
            raise Exception('Not yet implemented')

        if args.config_file is None:
            parser.error('Need to supply config file.')

        if args.slurm:
            self.run_slurm(args.config_file, args.experiment)
        else:
            self.run_local(args.config_file, args.experiment, args.job_id, args.sweep_id)


class Slate_Runner():
    """Base class for runners; subclasses override setup() and run()."""

    def __init__(self, slate, config):
        self.slate = slate
        self.config = config

    def setup(self):
        pass

    def run(self, run):
        pass


class Print_Config_Runner(Slate_Runner):
    """Print the raw and the expanded config, then consume all of it."""

    def run(self, run):
        slate, config = self.slate, self.config

        ptr = {'ptr': config}
        pprint(config)
        print('---')
        pprint(slate.consume(ptr, 'ptr', expand=True))
        # Mark everything as consumed so no leftover warning is printed.
        config.clear()


class PDB_Runner(Slate_Runner):
    """Drop into the debugger instead of running anything."""

    def run(self, run):
        d()


if __name__ == '__main__':
    raise Exception('You are using it wrong...')