380 lines
14 KiB
Python
380 lines
14 KiB
Python
import collections.abc
import copy
import datetime
import math
import os
import random
import shlex
from functools import partial
from multiprocessing import Process
from pprint import pprint
from threading import Thread

import git
import wandb
import yaml
|
|
|
|
import pdb
from pdb import set_trace as d  # shorthand: call d() anywhere to open the debugger

# When True, config keys left unconsumed after a run raise an exception
# instead of only being printed.
REQUIRE_CONFIG_CONSUMED = False
# Defaults for wandb.Settings(start_method=...) and wandb.init(reinit=...).
DEFAULT_START_METHOD, DEFAULT_REINIT = 'fork', True

# Primitive used to run several agents of one job concurrently
# (multiprocessing's Process is the alternative noted here).
Parallelization_Primitive = Thread  # Process
|
|
|
|
# pyslurm is optional; it is only required for submitting jobs to Slurm.
# NOTE: the flag keeps its historical (misspelled) name because other code
# checks `slurm_avaible`.
try:
    import pyslurm
except ImportError:
    slurm_avaible = False
    print('[!] Slurm not available.')
else:
    slurm_avaible = True
|
|
|
|
# TODO: Implement Testing
|
|
# TODO: Implement Ablative
|
|
|
|
|
|
class Slate():
    """Load layered YAML experiment configs and run them locally or on Slurm.

    Every run is tracked with wandb, optionally as a trial of a wandb sweep.
    Config values are *consumed* (removed) as they are read, so keys nobody
    used can be reported once a run finishes.
    """

    def __init__(self, runners=None):
        """Create a Slate with the built-in runners plus `runners`.

        Args:
            runners: mapping of runner name -> Slate_Runner subclass, merged
                over the built-in 'printConfig' and 'pdb' runners.
        """
        self.runners = {
            'printConfig': Print_Config_Runner,
            'pdb': PDB_Runner,
        }
        self.runners.update(runners or {})
        self._version = False
        self.sweep_id = None
        # Snapshot of the merged config used for '{config[...]}' expansion;
        # initialised so consume() also works before load_config().
        self._config = {}

    def load_config(self, filename, name):
        """Load experiment `name` from `filename`, resolving its imports.

        The merged config is snapshotted for variable expansion and its
        top-level 'vars' section is dropped before the config is returned.
        """
        config, stack = self._load_config(filename, name)
        print('[i] Merged Configs: ', stack)
        self._config = copy.deepcopy(config)
        self.consume(config, 'vars', {})
        return config

    def _load_config(self, filename, name, stack=None):
        """Recursively load the YAML document whose 'name' equals `name`.

        The document's 'import' entry is a comma-separated string (or a list)
        of `file:exp`, `:exp` (same file), `file` (file:DEFAULT) or `$`
        (:DEFAULT). Imports are loaded in order and this document's own values
        are deep-merged on top of them.

        Returns:
            (config, stack): the merged config and every `file:exp` merged.

        Raises:
            Exception: on a malformed import or if `name` is not in `filename`.
        """
        # A fresh list per top-level call; a shared mutable default would keep
        # accumulating entries across load_config() calls.
        if stack is None:
            stack = []
        stack.append(f'{filename}:{name}')
        with open(filename, 'r') as f:
            for doc in yaml.safe_load_all(f):
                # Empty YAML documents load as None; skip anything non-mapping.
                if not isinstance(doc, dict) or 'name' not in doc or doc['name'] != name:
                    continue
                imports = doc.pop('import', None)
                if imports:
                    if isinstance(imports, str):
                        imports = imports.split(',')
                    for imp in imports:
                        imp = imp.strip()
                        if not imp:
                            continue  # tolerate trailing or doubled commas
                        if imp == "$":
                            imp = ':DEFAULT'
                        rel_path, *opt = imp.split(':')
                        if len(opt) == 0:
                            nested_name = 'DEFAULT'
                        elif len(opt) == 1:
                            nested_name = opt[0]
                        else:
                            raise Exception('Malformed import statement. Must be <import file:exp>, <import :exp>, <import file> for file:DEFAULT or <import $> for :DEFAULT.')
                        if rel_path:
                            nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path))
                        else:
                            nested_path = filename
                        child, stack = self._load_config(nested_path, nested_name, stack=stack)
                        doc = self.deep_update(child, doc)
                return doc, stack
        raise Exception(f'Unable to find experiment <{name}> in <{filename}>')

    def deep_update(self, d, u):
        """Recursively merge mapping `u` into `d` in place and return `d`.

        String keys of `u` may be dotted paths ('a.b.c') addressing nested
        entries of `d`; missing (or non-mapping) intermediate levels are
        replaced by empty dicts. Mapping values are merged recursively into
        the addressed entry, any other value overwrites it.
        """
        for kstr, v in u.items():
            path = kstr.split('.') if isinstance(kstr, str) else [kstr]
            *parents, leaf = path
            head = d
            for k in parents:
                if not isinstance(head.get(k), collections.abc.MutableMapping):
                    head[k] = {}
                head = head[k]
            if isinstance(v, collections.abc.Mapping):
                # Merge into the entry at the addressed level (not the root).
                target = head.get(leaf)
                if not isinstance(target, collections.abc.MutableMapping):
                    target = {}
                head[leaf] = self.deep_update(target, v)
            else:
                head[leaf] = v
        return d

    def expand_vars(self, string, **kwargs):
        """Format `string` with `kwargs` plus a fresh random int named 'rand'.

        The exact string '{rand}' yields the int itself instead of its text.
        Non-string values are returned unchanged.
        """
        if isinstance(string, str):
            rand = int(random.random()*99999999)
            if string == '{rand}':
                return rand
            return string.format(**kwargs, rand=rand)
        return string

    def apply_nested(self, d, f):
        """Replace every leaf of the nested dict/list structure `d` with f(leaf), in place."""
        for k, v in d.items():
            if isinstance(v, dict):
                self.apply_nested(v, f)
            elif isinstance(v, list):
                for i, e in enumerate(v):
                    ptr = {'PTR': e}
                    self.apply_nested(ptr, f)
                    v[i] = ptr['PTR']
            else:
                d[k] = f(v)

    def deep_expand_vars(self, dict, **kwargs):
        """Expand '{...}' placeholders in every leaf string of `dict`, in place."""
        self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))

    def consume(self, conf, key, default=None, expand=False, **kwargs):
        """Remove and return the value at (possibly dotted) `key` in `conf`.

        Args:
            conf: config mapping; the consumed key is deleted from it.
            key: key or dotted path ('a.b.c') into nested mappings.
            default: returned when the key is absent (or `conf` is not a
                mapping). None means "required": a missing key raises KeyError.
            expand: if True and the value is a dict or list, placeholders in
                all nested strings are expanded in place.
            **kwargs: extra names for placeholder expansion, alongside
                'config' (the merged config snapshot) and 'rand'.

        String values are always expanded, repeatedly, until no '{' remains,
        the expansion yields a non-string (e.g. '{rand}'), or it stops
        changing the string.
        """
        keys_arr = key.split('.')
        if len(keys_arr) > 1:
            child = conf.get(keys_arr[0], {})
            child_keys = '.'.join(keys_arr[1:])
            return self.consume(child, child_keys, default=default, expand=expand, **kwargs)

        k = keys_arr[0]
        if default is not None:
            if not isinstance(conf, collections.abc.Mapping):
                return default
            val = conf.get(k, default)
        else:
            val = conf[k]
        if k in conf:
            del conf[k]

        if expand and isinstance(val, (dict, list)):
            # apply_nested walks a mapping, so wrap lists (and dicts) in one.
            holder = {'PTR': val}
            self.deep_expand_vars(holder, config=self._config, **kwargs)
            val = holder['PTR']
        elif isinstance(val, str):
            val = self._expand_string(val, **kwargs)
        return val

    def _expand_string(self, val, **kwargs):
        """Expand placeholders in `val` until none remain or no progress is made."""
        while isinstance(val, str) and '{' in val:
            expanded = self.expand_vars(val, config=self._config, **kwargs)
            if expanded == val:
                break  # self-referential value; stop instead of looping forever
            val = expanded
        return val

    def get_version(self):
        """Return (and cache) the commit hash of the enclosing git repository."""
        if not self._version:
            repo = git.Repo(search_parent_directories=True)
            self._version = repo.head.object.hexsha
        return self._version

    def _calc_num_jobs(self, schedC):
        """Jobs needed to cover 'repetitions' at agents_per_job * reps_per_agent per job."""
        schedulerC = copy.deepcopy(schedC)
        reps = self.consume(schedulerC, 'repetitions', 1)
        agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
        reps_per_job = reps_per_agent * agents_per_job
        return math.ceil(reps / reps_per_job)

    def _reps_for_job(self, schedC, job_id):
        """Return the repetition indices job `job_id` is responsible for.

        With job_id None every repetition is returned. Otherwise repetitions
        are dealt round-robin over the jobs and bucket `job_id - 1` is
        returned (Python indexing, so id 0 maps to the last bucket and ids
        0..n-1 as well as 1..n each cover every bucket exactly once).
        """
        schedulerC = copy.deepcopy(schedC)
        num_jobs = self._calc_num_jobs(schedulerC)
        reps = self.consume(schedulerC, 'repetitions', 1)
        if job_id is None:
            return list(range(reps))
        # One independent list per job; `[[]] * num_jobs` aliased a single
        # list, which handed every job every repetition.
        reps_for_job = [list(range(j, reps, num_jobs)) for j in range(num_jobs)]
        return reps_for_job[job_id - 1]

    def run_local(self, filename, name, job_id, sweep_id):
        """Run this job's share of experiment `name` from `filename` in-process."""
        config = self.load_config(filename, name)
        schedulerC = copy.deepcopy(config.get('scheduler', {}))
        rep_ids = self._reps_for_job(schedulerC, job_id)
        self.sweep_id = sweep_id
        self._init_sweep(config)
        self._fork_processes(config, rep_ids)

    def run_slurm(self, filename, name):
        """Submit experiment `name` from `filename` as a Slurm job array.

        Each array task re-invokes main.py with its task id (and the sweep id
        when a sweep was created) to run its share of the repetitions. Keys
        of the 'slurm' section not used by Slate are forwarded to
        pyslurm.JobSubmitDescription.
        """
        assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
        config = self.load_config(filename, name)
        slurmC = self.consume(config, 'slurm', expand=True)
        schedC = self.consume(config, 'scheduler', {})
        s_name = self.consume(slurmC, 'name')

        # Pre Validation: the runner must exist, the wandb section must
        # expand, and (optionally) the runner must set up. The runner gets a
        # copy so whatever it consumes does not alter `config`.
        runnerName = self.consume(config, 'runner')
        self.consume(copy.deepcopy(config), 'wandb', {}, expand=True)
        if self.consume(slurmC, 'pre_validate', True):
            Runner = self.runners[runnerName]
            runner = Runner(self, copy.deepcopy(config))
            runner.setup()

        self._init_sweep(config)
        self.consume(config, 'wandb', {})

        python_script = 'main.py'
        sh_lines = ['#!/bin/bash']
        sh_lines += self.consume(slurmC, 'sh_lines', [])
        if venv := self.consume(slurmC, 'venv', False):
            sh_lines += [f'source activate {venv}']
        # '-s' is the --slurm flag, so the sweep id must travel via
        # --sweep_id, and only when a sweep actually exists.
        cmd = f'python3 {python_script} {shlex.quote(filename)} {shlex.quote(name)} -j $SLURM_ARRAY_TASK_ID'
        if self.sweep_id is not None:
            cmd += f' --sweep_id {shlex.quote(str(self.sweep_id))}'
        sh_lines.append(cmd)
        script = "\n".join(sh_lines)

        num_jobs = self._calc_num_jobs(schedC)
        last_job_idx = num_jobs - 1
        num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
        array = f'0-{last_job_idx}%{num_parallel_jobs}'
        job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
        job_id = job.submit()
        print(f'[>] Job submitted to slurm with id {job_id}')
        with open('job_hist.log', 'a') as f:
            f.write(f'{filename}:{name} submitted to slurm with ids {job_id}_0-{last_job_idx} on [git:{self.get_version()}] at {datetime.datetime.now()}\n')

    def _fork_processes(self, config, rep_ids):
        """Run `rep_ids`, spread over up to `agents_per_job` parallel agents.

        Each agent normally gets `reps_per_agent` repetitions. If this job
        holds more than agents_per_job * reps_per_agent repetitions (e.g. a
        local run without a job id), the surplus is shared out evenly instead
        of being silently dropped.
        """
        schedC = self.consume(config, 'scheduler', {})
        agents_per_job = self.consume(schedC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)

        node_reps = len(rep_ids)
        num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))

        if num_p == 1:
            print('[i] Running within main thread')
            self._run_process(config, rep_ids=rep_ids, p_ind=0)
            return

        if num_p <= 0 or node_reps <= num_p * reps_per_agent:
            shares = [min(reps_per_agent, node_reps - p * reps_per_agent) for p in range(max(num_p, 0))]
        else:
            base, extra = divmod(node_reps, num_p)
            shares = [base + (1 if p < extra else 0) for p in range(num_p)]

        procs = []
        reps_done = 0
        for p, num_reps in enumerate(shares):
            print(f'[i] Spawning seperate thread/process ({p+1}/{num_p})')
            proc_rep_ids = rep_ids[reps_done:reps_done + num_reps]
            proc = Parallelization_Primitive(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
            proc.start()
            procs.append(proc)
            reps_done += num_reps

        for proc in procs:
            proc.join()
        print(f'[i] All threads/processes have terminated')

    def _init_sweep(self, config):
        """Create a wandb sweep from the 'sweep' section if enabled and no sweep id is set.

        The 'enable' flag and then the whole 'sweep' section are consumed;
        the remainder is the sweep configuration handed to wandb.sweep.
        """
        if self.sweep_id is None and self.consume(config, 'sweep.enable', False):
            sweepC = self.consume(config, 'sweep')
            wandbC = copy.deepcopy(config.get('wandb', {}))
            project = self.consume(wandbC, 'project')
            # wandb.sweep takes sweep/entity/project but no `settings`;
            # start_method only applies to wandb.init.
            self.sweep_id = wandb.sweep(
                sweep=sweepC,
                project=project,
                entity=wandbC.get('entity'),
            )

    def _run_process(self, orig_config, rep_ids, p_ind):
        """Run `rep_ids` in this thread/process, as a sweep agent if a sweep is active."""
        config = copy.deepcopy(orig_config)
        self.consume(config, 'sweep.enable', False)
        # The sweep definition is only needed to create the sweep, not by runs.
        self.consume(config, 'sweep', {})
        # _init_sweep() consumes the 'sweep' section when it creates a sweep,
        # so the presence of a sweep id is what marks a sweep run.
        if self.sweep_id is not None:
            wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
        else:
            self._run_single(config, rep_ids, p_ind=p_ind)

    def _check_consumed(self, config):
        """Report leftover config keys; raise instead if REQUIRE_CONFIG_CONSUMED."""
        if config != {}:
            msg = ('Config was not completely consumed: ', config)
            if REQUIRE_CONFIG_CONSUMED:
                raise Exception(msg)
            print(msg)

    def _run_single(self, orig_config, rep_ids, p_ind):
        """Run each repetition in `rep_ids` as its own wandb run on a fresh config copy."""
        print(f'[P{p_ind}] I will work on reps {rep_ids}')
        runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
        project = self.consume(wandbC, 'project')
        # Consumed once up front; consuming inside the loop meant only the
        # first repetition saw the user's values.
        reinit = self.consume(wandbC, 'reinit', DEFAULT_REINIT)
        start_method = self.consume(wandbC, 'start_method', DEFAULT_START_METHOD)

        Runner = self.runners[runnerName]

        for _rep in rep_ids:
            # Each repetition starts from the untouched config.
            config = copy.deepcopy(orig_config)
            with wandb.init(
                project=project,
                config=copy.deepcopy(config),
                reinit=reinit,
                settings=wandb.Settings(start_method=start_method),
                **wandbC
            ) as run:
                runner = Runner(self, config)
                runner.setup()
                runner.run(run)

            self._check_consumed(config)

    def _run_from_sweep(self, orig_config, p_ind):
        """Run one sweep trial: merge the sweep's parameters into the config and run."""
        # wandb.agent calls this repeatedly with the same `orig_config`
        # object, so never consume from it directly.
        base_config = copy.deepcopy(orig_config)
        runnerName, wandbC = self.consume(base_config, 'runner'), self.consume(base_config, 'wandb', {}, expand=True)
        project = self.consume(wandbC, 'project')
        reinit = self.consume(wandbC, 'reinit', DEFAULT_REINIT)
        start_method = self.consume(wandbC, 'start_method', DEFAULT_START_METHOD)

        Runner = self.runners[runnerName]

        with wandb.init(
            project=project,
            reinit=reinit,
            settings=wandb.Settings(start_method=start_method),
            **wandbC
        ) as run:
            config = copy.deepcopy(base_config)
            # Sweep parameters may use dotted keys ('a.b') for nested entries.
            self.deep_update(config, wandb.config)
            # Record the merged config on the run (update in place rather
            # than rebinding the run's config attribute).
            run.config.update(copy.deepcopy(config), allow_val_change=True)
            runner = Runner(self, config)
            runner.setup()
            runner.run(run)

        self._check_consumed(config)

    def from_args(self):
        """Parse the command line and dispatch to run_slurm() or run_local()."""
        import argparse

        parser = argparse.ArgumentParser()
        parser.add_argument("config_file", nargs='?', default=None)
        parser.add_argument("experiment", nargs='?', default='DEFAULT')
        parser.add_argument("-s", "--slurm", action="store_true")
        parser.add_argument("-w", "--worker", action="store_true")
        parser.add_argument("-j", "--job_id", default=None, type=int)
        parser.add_argument("--sweep_id", default=None, type=str)

        args = parser.parse_args()

        print(f'[i] I have job_id {args.job_id}')
        print(f'[i] Running on version [git:{self.get_version()}]')

        if args.worker:
            raise Exception('Not yet implemented')

        if args.config_file is None:
            # argparse reports usage and exits; an assert would vanish under -O.
            parser.error('Need to supply config file.')
        if args.slurm:
            self.run_slurm(args.config_file, args.experiment)
        else:
            self.run_local(args.config_file, args.experiment, args.job_id, args.sweep_id)
|
|
|
|
|
|
class Slate_Runner():
    """Base class for experiment runners driven by Slate.

    Subclasses override setup() and/or run(); both are no-ops here.
    """

    def __init__(self, slate, config):
        # Keep the owning Slate and the config dict for this run.
        self.slate, self.config = slate, config

    def setup(self):
        """Hook invoked before run(); does nothing by default."""

    def run(self, run):
        """Hook that performs the experiment; `run` is the run object Slate passes in."""
|
|
|
|
|
|
class Print_Config_Runner(Slate_Runner):
    """Runner that pretty-prints its config before and after placeholder expansion."""

    def run(self, run):
        cfg = self.config
        # consume() pops from a parent mapping, so give the config one.
        wrapper = {'ptr': cfg}
        pprint(cfg)
        print('---')
        expanded = self.slate.consume(wrapper, 'ptr', expand=True)
        pprint(expanded)
        # Empty the config so no keys are reported as unconsumed afterwards.
        cfg.clear()
|
|
|
|
|
|
class PDB_Runner(Slate_Runner):
    """Runner that opens the interactive pdb debugger instead of running an experiment."""

    def run(self, run):
        # Same function as the module-level `d` shorthand.
        pdb.set_trace()
|
|
|
|
|
|
# This module is a library; executing it directly is a usage error.
if __name__ == '__main__':
    raise RuntimeError(
        'slate is a library, not a script: import Slate into your own entry '
        'point (e.g. main.py) and call Slate(runners).from_args() there.'
    )
|