import wandb
import yaml
import os
import math
import random
import copy
import collections.abc
from functools import partial
from multiprocessing import Process
from threading import Thread

import pdb
# Debug shortcut: calling `d()` drops into pdb.
d = pdb.set_trace

# When True, config keys left unconsumed after a run raise instead of being printed.
REQUIRE_CONFIG_CONSUMED = False

# Primitive used to run agents in parallel; swap for `Process` to use processes.
Parallelization_Primitive = Thread  # Process

try:
    import pyslurm
except ImportError:
    slurm_avaible = False
    print('[!] Slurm not avaible.')
else:
    slurm_avaible = True

# TODO: Implement Slurm
# TODO: Implement Parallel
# TODO: Implement Testing
# TODO: Implement Ablative
# TODO: Implement PCA


class Slate:
    """Load layered YAML experiment configs and dispatch them to runners.

    ``runners`` maps a runner name (selected by the config's ``runner`` key)
    to a callable invoked as ``runner(slate, run, config)``. Keys the runner
    does not remove from ``config`` (e.g. via :meth:`consume`) are reported
    after the run, or raise if ``REQUIRE_CONFIG_CONSUMED`` is set.
    """

    def __init__(self, runners):
        self.runners = runners
        self.runners['printConfig'] = print_config_runner

    def load_config(self, filename, name):
        """Load experiment ``name`` from ``filename`` with its imports merged.

        A deep copy of the merged config is kept in ``self._config`` (used as
        ``config`` for ``{...}`` string expansion); the ``vars`` section is
        consumed from the returned config.
        """
        config, stack = self._load_config(filename, name)
        print('[i] Merged Configs: ', stack)
        self._config = copy.deepcopy(config)
        self.consume(config, 'vars', {})
        return config

    def _load_config(self, filename, name, stack=None):
        """Find the YAML document named ``name`` in ``filename`` and resolve imports.

        ``import`` is a comma-separated string (or a list) of entries of the
        form ``file:exp``, ``:exp`` (same file), ``file`` (``file:DEFAULT``)
        or ``$`` (``:DEFAULT``). Relative paths are resolved against the
        directory of ``filename``. For each import, the imported config is
        used as the base and the current document is deep-merged on top.

        Returns:
            ``(config, stack)`` where ``stack`` lists every ``file:name`` loaded.

        Raises:
            Exception: if the experiment is not found or an import is malformed.
        """
        # Fresh list per top-level call; a `[]` default would be shared across
        # calls and accumulate entries from earlier loads.
        if stack is None:
            stack = []
        stack.append(f'{filename}:{name}')
        with open(filename, 'r') as f:
            for doc in yaml.safe_load_all(f):
                # Empty documents load as None; only mappings can carry a name.
                if not isinstance(doc, dict) or 'name' not in doc or doc['name'] != name:
                    continue
                if 'import' in doc:
                    imports = doc.pop('import')
                    if isinstance(imports, str):
                        imports = imports.split(',')
                    for imp in imports:
                        imp = imp.strip()
                        if not imp:
                            continue  # tolerate trailing / doubled commas
                        if imp == "$":
                            imp = ':DEFAULT'
                        rel_path, *opt = imp.split(':')
                        if len(opt) == 0:
                            nested_name = 'DEFAULT'
                        elif len(opt) == 1:
                            nested_name = opt[0]
                        else:
                            raise Exception(
                                'Malformed import statement. Must be <file:exp>, <:exp>, '
                                '<file> for file:DEFAULT or <$> for :DEFAULT.')
                        if rel_path:
                            nested_path = os.path.normpath(
                                os.path.join(os.path.dirname(filename), rel_path))
                        else:
                            nested_path = filename
                        child, stack = self._load_config(nested_path, nested_name, stack=stack)
                        doc = self.deep_update(child, doc)
                return doc, stack
        raise Exception(f'Unable to find experiment <{name}> in <{filename}>')

    def deep_update(self, d, u):
        """Recursively merge mapping ``u`` into ``d`` in place and return ``d``.

        String keys containing dots (``'a.b.c'``) address nested entries;
        missing or non-mapping intermediate entries are replaced with dicts.
        Mapping values are merged into the existing value at the target path;
        all other values overwrite it.
        """
        for kstr, v in u.items():
            ks = kstr.split('.') if isinstance(kstr, str) else [kstr]
            *parents, leaf = ks
            head = d
            for k in parents:
                if not isinstance(head.get(k), collections.abc.MutableMapping):
                    head[k] = {}
                head = head[k]
            if isinstance(v, collections.abc.Mapping):
                # Merge into the value at the full dotted path (not the root).
                existing = head.get(leaf)
                if not isinstance(existing, collections.abc.MutableMapping):
                    existing = {}
                head[leaf] = self.deep_update(existing, v)
            else:
                head[leaf] = v
        return d

    def expand_vars(self, string, **kwargs):
        """Format ``string`` with ``kwargs`` plus a random ``rand`` value.

        The exact string ``'{rand}'`` yields an ``int``; non-strings are
        returned unchanged.
        """
        if isinstance(string, str):
            if string == '{rand}':
                return int(random.random()*99999999)
            return string.format(**kwargs, rand=int(random.random()*99999999))
        return string

    def apply_nested(self, d, f):
        """Replace every leaf value of the nested dict/list ``d`` with ``f(leaf)`` in place."""
        for k, v in d.items():
            if isinstance(v, dict):
                self.apply_nested(v, f)
            elif isinstance(v, list):
                for i, e in enumerate(v):
                    # Wrap the element so scalars can be transformed, then
                    # write the result back into the list.
                    holder = {'PTR': e}
                    self.apply_nested(holder, f)
                    v[i] = holder['PTR']
            else:
                d[k] = f(v)

    def deep_expand_vars(self, dict, **kwargs):
        """Apply :meth:`expand_vars` to every leaf of ``dict`` in place."""
        self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))

    def consume(self, conf, key, default=None, expand=False, **kwargs):
        """Remove and return ``key`` (dotted path allowed) from ``conf``.

        Args:
            conf: config mapping; the consumed key is deleted from it.
            key: key or dotted path such as ``'sweep.enable'``.
            default: returned when the key is missing. If ``None``, the key
                is required and a missing key raises ``KeyError``.
            expand: if True, dict/list values are expanded leaf by leaf.
            **kwargs: extra names available to ``{...}`` formatting.

        String values are formatted repeatedly (with ``config`` bound to the
        config loaded by :meth:`load_config`) until no ``{`` remains.
        """
        keys_arr = key.split('.')
        if len(keys_arr) > 1:
            child = conf.get(keys_arr[0])
            if not isinstance(child, collections.abc.MutableMapping):
                child = {}
            child_keys = '.'.join(keys_arr[1:])
            return self.consume(child, child_keys, default=default, expand=expand, **kwargs)

        k = keys_arr[0]
        if default is not None:
            val = conf.get(k, default)
        else:
            val = conf[k]
        if k in conf:
            del conf[k]

        root = getattr(self, '_config', {})
        if expand and isinstance(val, (dict, list)):
            holder = {'PTR': val}
            self.deep_expand_vars(holder, config=root, **kwargs)
            val = holder['PTR']
        else:
            # expand_vars may return an int (for '{rand}'), which ends the loop.
            while isinstance(val, str) and '{' in val:
                val = self.expand_vars(val, config=root, **kwargs)
        return val

    def _calc_num_jobs(self, schedC):
        """Number of jobs needed for ``repetitions`` given per-job capacity."""
        schedulerC = copy.deepcopy(schedC)
        reps = self.consume(schedulerC, 'repetitions', 1)
        agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
        reps_per_job = reps_per_agent * agents_per_job
        jobs_needed = math.ceil(reps / reps_per_job)
        return jobs_needed

    def _reps_for_job(self, schedC, job_id):
        """Repetition indices handled by ``job_id`` (all of them if ``job_id`` is None).

        Reps are distributed round-robin: job ``j`` gets ``j, j + n, ...``
        where ``n`` is the number of jobs.

        Raises:
            IndexError: if ``job_id`` is outside ``[0, num_jobs)``.
        """
        schedulerC = copy.deepcopy(schedC)
        reps = self.consume(schedulerC, 'repetitions', 1)
        if job_id is None:
            return list(range(reps))
        # Use the untouched section: `schedulerC` no longer contains
        # `repetitions`, which would make the job count fall back to 1.
        num_jobs = self._calc_num_jobs(schedC)
        if not 0 <= job_id < num_jobs:
            raise IndexError(f'job_id {job_id} out of range for {num_jobs} jobs')
        return list(range(job_id, reps, num_jobs))

    def run_local(self, filename, name, job_id):
        """Run the reps belonging to ``job_id`` (or all reps) on this machine."""
        config = self.load_config(filename, name)
        schedulerC = copy.deepcopy(config.get('scheduler', {}))
        rep_ids = self._reps_for_job(schedulerC, job_id)
        self._fork_processes(config, rep_ids)

    def run_slurm(self, filename, name):
        """Submit the experiment as a slurm job array, one array task per job.

        Each task re-invokes ``main.py`` with ``-j $SLURM_ARRAY_TASK_ID``.
        Keys left in the ``slurm`` section are forwarded to
        ``pyslurm.JobSubmitDescription``.
        """
        if not slurm_avaible:
            raise RuntimeError('pyslurm does not seem to be installed on this system.')
        config = self.load_config(filename, name)
        slurmC = self.consume(config, 'slurm')
        schedC = self.consume(config, 'scheduler', {})
        s_name = self.consume(slurmC, 'name')

        python_script = 'main.py'
        sh_lines = ['#!/bin/bash']
        sh_lines += self.consume(slurmC, 'sh_lines', [])
        if venv := self.consume(slurmC, 'venv', False):
            sh_lines += [f'source activate {venv}']
        sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
        script = "\n".join(sh_lines)

        num_jobs = self._calc_num_jobs(schedC)
        last_job_idx = num_jobs - 1
        num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
        array = f'0-{last_job_idx}%{num_parallel_jobs}'
        job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
        job_id = job.submit()
        print(f'[i] Job submitted to slurm with id {job_id}')

    def _fork_processes(self, config, rep_ids):
        """Split ``rep_ids`` across up to ``agents_per_job`` parallel agents.

        Each agent normally takes ``reps_per_agent`` reps. If there are more
        reps than ``agents_per_job * reps_per_agent``, the chunks grow so that
        every rep is still assigned.
        """
        schedC = self.consume(config, 'scheduler', {})
        agents_per_job = self.consume(schedC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)

        node_reps = len(rep_ids)
        num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))

        if num_p == 0:
            return
        if num_p == 1:
            # Route through _run_process so sweep configs are honoured even
            # when only one agent is needed.
            self._run_process(config, rep_ids=rep_ids, p_ind=0)
            return

        chunk = max(reps_per_agent, math.ceil(node_reps / num_p))
        procs = []
        for p, start in enumerate(range(0, node_reps, chunk)):
            proc_rep_ids = list(rep_ids[start:start + chunk])
            proc = Parallelization_Primitive(
                target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
            proc.start()
            procs.append(proc)

        for proc in procs:
            proc.join()

    def _run_process(self, orig_config, rep_ids, p_ind):
        """Run one agent: a wandb sweep agent if ``sweep.enable``, else plain reps."""
        config = copy.deepcopy(orig_config)
        if self.consume(config, 'sweep.enable', False):
            sweepC = self.consume(config, 'sweep')
            project = config['wandb']['project']
            sweep_id = wandb.sweep(
                sweep=sweepC,
                project=project
            )
            wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
        else:
            self.consume(config, 'sweep', {})
            self._run_single(config, rep_ids, p_ind=p_ind)

    def _check_consumed(self, config):
        """Report (or raise on) config keys the runner left unconsumed."""
        if config != {}:
            msg = ('Config was not completely consumed: ', config)
            if REQUIRE_CONFIG_CONSUMED:
                raise Exception(msg)
            print(msg)

    def _run_single(self, orig_config, rep_ids, p_ind):
        """Run the configured runner once per rep, each inside its own wandb run."""
        print(f'[P{p_ind}] I will work on reps {rep_ids}')
        runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {})
        project = self.consume(wandbC, 'project')
        runner = self.runners[runnerName]

        for r in rep_ids:
            # Every rep starts from the same unconsumed config.
            config = copy.deepcopy(orig_config)
            with wandb.init(
                project=project,
                config=copy.deepcopy(config),
                **wandbC
            ) as run:
                runner(self, run, config)
                self._check_consumed(config)

    def _run_from_sweep(self, orig_config, p_ind):
        """Run one sweep trial: base config overridden by the sweep's ``wandb.config``."""
        # wandb.agent calls this repeatedly with the *same* bound config, so
        # consume from a private copy to keep `runner`/`wandb` for later trials.
        base_config = copy.deepcopy(orig_config)
        runnerName, wandbC = self.consume(base_config, 'runner'), self.consume(base_config, 'wandb', {})
        project = self.consume(wandbC, 'project')
        runner = self.runners[runnerName]

        with wandb.init(
            project=project,
            **wandbC
        ) as run:
            config = copy.deepcopy(base_config)
            self.deep_update(config, wandb.config)
            runner(self, run, config)
            self._check_consumed(config)

    def from_args(self):
        """Parse CLI arguments and run locally or submit to slurm."""
        import argparse

        parser = argparse.ArgumentParser()
        parser.add_argument("config_file", nargs='?', default=None)
        parser.add_argument("experiment", nargs='?', default='DEFAULT')
        parser.add_argument("-s", "--slurm", action="store_true")
        parser.add_argument("-w", "--worker", action="store_true")
        parser.add_argument("-j", "--job_id", default=None, type=int)

        args = parser.parse_args()

        if args.worker:
            raise Exception('Not yet implemented')

        if args.config_file is None:
            parser.error('Need to supply config file.')

        if args.slurm:
            self.run_slurm(args.config_file, args.experiment)
        else:
            self.run_local(args.config_file, args.experiment, args.job_id)


def print_config_runner(slate, run, config):
    """Built-in runner: print the config and mark all of it as consumed."""
    print(config)
    config.clear()


if __name__ == '__main__':
    raise Exception('You are using it wrong...')