Slate/slate/slate.py
2023-07-27 13:09:34 +02:00

385 lines
14 KiB
Python

import wandb
import yaml
import os
import math
import random
import copy
import collections.abc
from functools import partial
from multiprocessing import Process
from threading import Thread
import git
import datetime
from pprint import pprint
import pdb
# Debug shorthand: calling d() anywhere drops into the pdb debugger.
d = pdb.set_trace

# Behaviour switches read by Slate at run time.
REQUIRE_CONFIG_CONSUMED = False  # raise instead of print when a run leaves config keys unconsumed
DEFAULT_START_METHOD = 'fork'  # wandb.Settings start_method unless the wandb config overrides it
DEFAULT_REINIT = True  # wandb.init reinit flag unless the wandb config overrides it
Parallelization_Primitive = Thread  # swap for multiprocessing.Process to run agents in processes

# pyslurm is optional; it is only needed to submit jobs via Slate.run_slurm.
try:
    import pyslurm
    slurm_avaible = True
except ImportError:
    slurm_avaible = False
    print('[!] Slurm not avaible.')

# TODO: Implement Testing
# TODO: Implement Ablative
class Slate():
    """Experiment launcher driven by layered YAML configs.

    A config is loaded by experiment name (``load_config``), merged with any
    experiments it imports, and handed to a runner class looked up by the
    config's ``runner`` key. Config values are consumed destructively via
    ``consume``; keys left over after a run are reported (or raise when
    ``REQUIRE_CONFIG_CONSUMED`` is set). Runs execute locally (``run_local``,
    possibly split over several threads/processes) or as a slurm array job
    (``run_slurm``), optionally as part of a wandb sweep.
    """

    def __init__(self, runners):
        """Create a launcher.

        Args:
            runners: mapping of runner name -> runner class. A runner class is
                constructed as ``Runner(slate, config)`` and then gets
                ``setup()`` and ``run(run)`` called. Entries override the
                built-in ``printConfig`` and ``pdb`` runners.
        """
        self.runners = {
            'printConfig': Print_Config_Runner,
            'pdb': PDB_Runner,
        }
        self.runners.update(runners)
        self._version = False  # cached git sha, filled lazily by get_version()
        self.sweep_id = None  # wandb sweep id (from the CLI or created by _init_sweep)

    def load_config(self, filename, name):
        """Load experiment ``name`` from ``filename`` with all imports merged.

        A deep copy of the merged config (including ``vars``) is kept in
        ``self._config`` so ``{config[...]}`` placeholders can be expanded
        later; the returned config has ``vars`` removed.
        """
        config, stack = self._load_config(filename, name)
        print('[i] Merged Configs: ', stack)
        self._config = copy.deepcopy(config)
        self.consume(config, 'vars', {})
        return config

    def _load_config(self, filename, name, stack=None):
        """Load experiment ``name`` from the YAML file ``filename``.

        The file may hold several YAML documents; the one whose ``name``
        matches is used. Its optional ``import`` entry (comma-separated string
        or list) names other experiments -- ``file:exp``, ``:exp``, ``file``
        (= ``file:DEFAULT``) or ``$`` (= ``:DEFAULT``), paths relative to
        ``filename`` -- which are loaded first and then overridden by this
        document. Earlier imports take precedence over later ones.

        Returns:
            ``(config, stack)`` where ``stack`` lists every ``file:name`` loaded.

        Raises:
            Exception: if the experiment is not found or an import is malformed.
        """
        # None instead of a `[]` default: a default list is shared by every call
        # and kept growing across load_config() invocations.
        if stack is None:
            stack = []
        stack.append(f'{filename}:{name}')
        with open(filename, 'r') as f:
            for doc in yaml.safe_load_all(f):
                # Empty documents load as None; only mappings can name an experiment.
                if not isinstance(doc, dict) or 'name' not in doc or doc['name'] != name:
                    continue
                imports = doc.pop('import', None) or []
                if isinstance(imports, str):
                    imports = imports.split(',')
                for imp in imports:
                    imp = imp.strip()
                    if not imp:
                        continue  # tolerate trailing or doubled commas
                    if imp == "$":
                        imp = ':DEFAULT'
                    rel_path, *opt = imp.split(':')
                    if len(opt) == 0:
                        nested_name = 'DEFAULT'
                    elif len(opt) == 1:
                        nested_name = opt[0]
                    else:
                        raise Exception('Malformed import statement. Must be <import file:exp>, <import :exp>, <import file> for file:DEFAULT or <import $> for :DEFAULT.')
                    # An empty path (":exp") refers to the current file.
                    nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if len(rel_path) else filename
                    child, stack = self._load_config(nested_path, nested_name, stack=stack)
                    doc = self.deep_update(child, doc)
                return doc, stack
        raise Exception(f'Unable to find experiment <{name}> in <{filename}>')

    def deep_update(self, d, u, traverse_dot_notation=True):
        """Recursively merge mapping ``u`` into ``d`` in place and return ``d``.

        With ``traverse_dot_notation`` a key such as ``'a.b'`` addresses
        ``d['a']['b']`` (missing or non-mapping intermediates become ``{}``).
        Below a ``parameters`` key dots are kept literal, because sweep
        parameter names may contain dots. Nested mappings are merged; any other
        value replaces the existing one.
        """
        for kstr, v in u.items():
            if traverse_dot_notation and isinstance(kstr, str):
                ks = kstr.split('.')
            else:
                ks = [kstr]
            # Decided per key: the old code flipped the shared flag, which also
            # disabled dot traversal for every later sibling key.
            child_traverse = traverse_dot_notation and 'parameters' not in ks
            parent = d
            for k in ks[:-1]:
                if not isinstance(parent.get(k), collections.abc.Mapping):
                    parent[k] = {}
                parent = parent[k]
            last = ks[-1]
            if isinstance(v, collections.abc.Mapping):
                # Merge into the value actually addressed by the (dotted) key,
                # not into a same-named key of the root.
                existing = parent.get(last)
                if not isinstance(existing, collections.abc.Mapping):
                    existing = {}
                parent[last] = self.deep_update(existing, v, traverse_dot_notation=child_traverse)
            else:
                parent[last] = v
        return d

    def expand_vars(self, string, **kwargs):
        """Expand ``{...}`` placeholders in ``string`` via ``str.format``.

        ``{rand}`` is a fresh random int; the exact string ``'{rand}'`` returns
        that int itself. Non-string inputs are returned unchanged.
        """
        if isinstance(string, str):
            rand = int(random.random()*99999999)
            if string == '{rand}':
                return rand
            return string.format(**kwargs, rand=rand)
        return string

    def apply_nested(self, d, f):
        """Replace every leaf of the nested dict/list structure ``d`` by ``f(leaf)`` in place."""
        for k, v in d.items():
            if isinstance(v, dict):
                self.apply_nested(v, f)
            elif isinstance(v, list):
                for i, e in enumerate(v):
                    # Box the element so dict, list and scalar items share one code path.
                    ptr = {'PTR': e}
                    self.apply_nested(ptr, f)
                    v[i] = ptr['PTR']
            else:
                d[k] = f(v)

    def deep_expand_vars(self, dict, **kwargs):
        """Expand placeholders in every leaf of ``dict`` in place (see expand_vars)."""
        self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))

    def consume(self, conf, key, default=None, expand=False, **kwargs):
        """Remove ``key`` from ``conf`` and return its (expanded) value.

        Args:
            conf: config mapping; the key is deleted from it (consumed).
            key: dot-separated path, e.g. ``'sweep.enable'`` for
                ``conf['sweep']['enable']``.
            default: returned when the key is missing. ``None`` means the key
                is required and a missing key raises KeyError.
            expand: if True, placeholders are expanded recursively (in place)
                inside dict/list values as well as in a plain string value.
            **kwargs: extra names for the ``str.format`` expansion, in addition
                to ``config`` (the merged config) and ``rand``.

        Without ``expand``, a string value is re-expanded until it contains no
        ``{``, stops changing, or expands to a non-string (``'{rand}'`` -> int).
        """
        head, *rest = key.split('.')
        if rest:
            child = conf.get(head, {}) if isinstance(conf, collections.abc.Mapping) else {}
            return self.consume(child, '.'.join(rest), default=default, expand=expand, **kwargs)

        if default is None:
            val = conf.pop(head)  # required key: KeyError if missing
        elif isinstance(conf, collections.abc.Mapping):
            val = conf.pop(head, default)
        else:
            return default

        if expand:
            # Box the value so strings and lists are expanded too, not only dicts.
            box = {'PTR': val}
            self.deep_expand_vars(box, config=self._config, **kwargs)
            val = box['PTR']
        else:
            while isinstance(val, str) and '{' in val:
                expanded = self.expand_vars(val, config=self._config, **kwargs)
                if expanded == val:
                    break  # self-referencing placeholder; avoid an endless loop
                val = expanded
        return val

    def get_version(self):
        """Return (and cache) the sha of HEAD of the enclosing git repository."""
        if not self._version:
            repo = git.Repo(search_parent_directories=True)
            self._version = repo.head.object.hexsha
        return self._version

    def _calc_num_jobs(self, schedC):
        """Return how many jobs are needed to cover all repetitions.

        Each job runs ``agents_per_job`` agents with ``reps_per_agent`` reps
        each; ``repetitions`` is the total. All default to 1. ``schedC`` is
        not modified.
        """
        schedulerC = copy.deepcopy(schedC)
        reps = self.consume(schedulerC, 'repetitions', 1)
        agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
        reps_per_job = reps_per_agent * agents_per_job
        return math.ceil(reps / reps_per_job)

    def _reps_for_job(self, schedC, job_id):
        """Return the repetition indices job ``job_id`` should run.

        Repetitions are dealt round-robin over the jobs. ``job_id=None`` means
        all repetitions. ``schedC`` is not modified.
        """
        schedulerC = copy.deepcopy(schedC)
        num_jobs = self._calc_num_jobs(schedulerC)
        reps = self.consume(schedulerC, 'repetitions', 1)
        if job_id is None:
            return list(range(reps))
        # One independent list per job; `[[]] * num_jobs` aliased a single list,
        # so every job received every repetition.
        reps_for_job = [[] for _ in range(num_jobs)]
        for i in range(reps):
            reps_for_job[i % num_jobs].append(i)
        # `job_id - 1` maps 1-based ids (1..n) and slurm's 0-based array ids
        # (0..n-1, where 0 wraps to the last slot) onto all n slots.
        return reps_for_job[job_id-1]

    def run_local(self, filename, name, job_id, sweep_id):
        """Run experiment ``name`` from ``filename`` on this machine.

        ``job_id`` selects this job's share of the repetitions (None = all);
        ``sweep_id`` joins an existing wandb sweep (None = create one if the
        config enables sweeping).
        """
        config = self.load_config(filename, name)
        schedulerC = copy.deepcopy(config.get('scheduler', {}))
        rep_ids = self._reps_for_job(schedulerC, job_id)
        self.sweep_id = sweep_id
        self._init_sweep(config)
        self._fork_processes(config, rep_ids)

    def run_slurm(self, filename, name):
        """Submit experiment ``name`` from ``filename`` as a slurm array job.

        Each array task re-runs ``main.py`` locally with its task id as
        ``-j`` (and the sweep id, if any). Remaining keys of the ``slurm``
        section are passed to ``pyslurm.JobSubmitDescription``. The submission
        is appended to ``job_hist.log``.
        """
        assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
        config = self.load_config(filename, name)
        slurmC = self.consume(config, 'slurm', expand=True)
        schedC = self.consume(config, 'scheduler', {})
        s_name = self.consume(slurmC, 'name')
        # Pre-validation: the runner must resolve and set up, and the wandb
        # section must expand. Work on copies so nothing here consumes keys
        # from the config used below.
        runnerName = self.consume(config, 'runner')
        self.consume(copy.deepcopy(config), 'wandb', {}, expand=True)
        if self.consume(slurmC, 'pre_validate', True):
            Runner = self.runners[runnerName]
            runner = Runner(self, copy.deepcopy(config))
            runner.setup()
        self._init_sweep(config)
        self.consume(config, 'wandb', {})
        python_script = 'main.py'
        sh_lines = ['#!/bin/bash']
        sh_lines += self.consume(slurmC, 'sh_lines', [])
        if venv := self.consume(slurmC, 'venv', False):
            sh_lines += [f'source activate {venv}']
        command = f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID'
        # Use the long option: `-s` is --slurm in from_args and would make every
        # array task resubmit itself. Omit it entirely when there is no sweep.
        if self.sweep_id is not None:
            command += f' --sweep_id {self.sweep_id}'
        sh_lines += [command]
        script = "\n".join(sh_lines)
        num_jobs = self._calc_num_jobs(schedC)
        last_job_idx = num_jobs - 1
        num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
        array = f'0-{last_job_idx}%{num_parallel_jobs}'
        job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
        job_id = job.submit()
        print(f'[>] Job submitted to slurm with id {job_id}')
        with open('job_hist.log', 'a') as f:
            f.write(f'{filename}:{name} submitted to slurm with ids {job_id}_0-{last_job_idx} on [git:{self.get_version()}] at {datetime.datetime.now()}\n')

    def _fork_processes(self, config, rep_ids):
        """Split ``rep_ids`` over up to ``agents_per_job`` agents and run them.

        A single agent runs in the calling thread. Otherwise each agent runs in
        its own ``Parallelization_Primitive`` (thread or process), and this
        method waits for all of them to finish.
        """
        schedC = self.consume(config, 'scheduler', {})
        agents_per_job = self.consume(schedC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
        node_reps = len(rep_ids)
        num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))
        if num_p == 1:
            print('[i] Running within main thread')
            self._run_process(config, rep_ids=rep_ids, p_ind=0)
            return
        # Each agent takes reps_per_agent consecutive reps. Reps beyond
        # agents_per_job * reps_per_agent (e.g. a local run without -j) are
        # dealt round-robin instead of being silently dropped.
        assignments = [list(rep_ids[p * reps_per_agent:(p + 1) * reps_per_agent]) for p in range(num_p)]
        leftover = rep_ids[num_p * reps_per_agent:] if num_p else []
        for i, rep in enumerate(leftover):
            assignments[i % num_p].append(rep)
        procs = []
        for p, proc_rep_ids in enumerate(assignments):
            print(f'[i] Spawning seperate thread/process ({p+1}/{num_p})')
            proc = Parallelization_Primitive(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
            proc.start()
            procs.append(proc)
        for proc in procs:
            proc.join()
        print(f'[i] All threads/processes have terminated')

    def _init_sweep(self, config):
        """Create a wandb sweep if ``sweep.enable`` is set and no sweep id is known.

        ``config`` is only read, not consumed, so ``_run_process`` still sees
        ``sweep.enable`` afterwards. The sweep definition sent to wandb is the
        ``sweep`` section without its ``enable`` flag.
        """
        if 'sweep' in config:
            pprint(config['sweep'])
        if self.sweep_id is not None:
            return
        sweepC = copy.deepcopy(config.get('sweep', {}))
        if not self.consume(sweepC, 'enable', False):
            return
        wandbC = copy.deepcopy(config.get('wandb', {}))
        project = self.consume(wandbC, 'project')
        self.sweep_id = wandb.sweep(
            sweep=sweepC,
            project=project
        )

    def _run_process(self, orig_config, rep_ids, p_ind):
        """Run one agent's share of repetitions on a private copy of the config.

        With ``sweep.enable`` the agent executes ``len(rep_ids)`` runs of wandb
        sweep ``self.sweep_id``; otherwise it runs plain repetitions.
        """
        config = copy.deepcopy(orig_config)
        sweep_enabled = self.consume(config, 'sweep.enable', False)
        # The sweep section only describes the sweep for wandb; runners never see it.
        self.consume(config, 'sweep', {})
        if sweep_enabled:
            wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
        else:
            self._run_single(config, rep_ids, p_ind=p_ind)

    def _run_single(self, orig_config, rep_ids, p_ind):
        """Run one wandb run per entry of ``rep_ids``, each on a fresh config copy.

        ``orig_config`` loses its ``runner`` and ``wandb`` keys.
        """
        print(f'[P{p_ind}] I will work on reps {rep_ids}')
        runnerName = self.consume(orig_config, 'runner')
        wandbC = self.consume(orig_config, 'wandb', {}, expand=True)
        project = self.consume(wandbC, 'project')
        # Read once: consuming these inside the loop dropped the user's values
        # after the first repetition.
        reinit = self.consume(wandbC, 'reinit', DEFAULT_REINIT)
        start_method = self.consume(wandbC, 'start_method', DEFAULT_START_METHOD)
        Runner = self.runners[runnerName]
        for _rep in rep_ids:
            config = copy.deepcopy(orig_config)
            with wandb.init(
                project=project,
                config=copy.deepcopy(config),
                reinit=reinit,
                settings=wandb.Settings(start_method=start_method),
                **wandbC
            ) as run:
                runner = Runner(self, config)
                runner.setup()
                runner.run(run)
            self._check_consumed(config)

    def _run_from_sweep(self, orig_config, p_ind):
        """Execute a single sweep run (called by ``wandb.agent``).

        The sweep's parameters (``wandb.config``, dotted names allowed) are
        merged over a copy of ``orig_config`` before the runner starts.
        """
        # wandb.agent calls this repeatedly with the same object; work on a copy
        # so `runner`/`wandb` are still present for the next run.
        orig_config = copy.deepcopy(orig_config)
        runnerName = self.consume(orig_config, 'runner')
        wandbC = self.consume(orig_config, 'wandb', {}, expand=True)
        project = self.consume(wandbC, 'project')
        Runner = self.runners[runnerName]
        with wandb.init(
            project=project,
            reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
            settings=wandb.Settings(start_method=self.consume(wandbC, 'start_method', DEFAULT_START_METHOD)),
            **wandbC
        ) as run:
            config = copy.deepcopy(orig_config)
            self.deep_update(config, wandb.config)
            run.config = copy.deepcopy(config)
            runner = Runner(self, config)
            runner.setup()
            runner.run(run)
        self._check_consumed(config)

    def _check_consumed(self, config):
        """Report (or raise, if REQUIRE_CONFIG_CONSUMED) keys no runner consumed."""
        if config != {}:
            msg = ('Config was not completely consumed: ', config)
            if REQUIRE_CONFIG_CONSUMED:
                raise Exception(msg)
            print(msg)

    def from_args(self, argv=None):
        """Parse command-line arguments and dispatch to run_slurm or run_local.

        Args:
            argv: argument list to parse; None (default) parses ``sys.argv``.
        """
        import argparse
        parser = argparse.ArgumentParser()
        parser.add_argument("config_file", nargs='?', default=None)
        parser.add_argument("experiment", nargs='?', default='DEFAULT')
        parser.add_argument("-s", "--slurm", action="store_true")
        parser.add_argument("-w", "--worker", action="store_true")
        parser.add_argument("-j", "--job_id", default=None, type=int)
        parser.add_argument("--sweep_id", default=None, type=str)
        args = parser.parse_args(argv)
        print(f'[i] I have job_id {args.job_id}')
        print(f'[i] Running on version [git:{self.get_version()}]')
        if args.worker:
            raise Exception('Not yet implemented')
        if args.config_file is None:
            parser.error('Need to supply config file.')
        if args.slurm:
            self.run_slurm(args.config_file, args.experiment)
        else:
            self.run_local(args.config_file, args.experiment, args.job_id, args.sweep_id)
class Slate_Runner():
    """Base class for runners dispatched by Slate.

    Slate constructs a runner as ``Runner(slate, config)`` and then calls
    ``setup()`` followed by ``run(run)``. Both hooks do nothing here;
    subclasses override them.
    """

    def __init__(self, slate, config):
        # Keep handles to the owning Slate (e.g. for slate.consume) and to the
        # config this runner is expected to consume.
        self.slate, self.config = slate, config

    def setup(self):
        """Preparation hook called before run(); a no-op by default."""
        return None

    def run(self, run):
        """Main hook receiving the object Slate passes as ``run``; a no-op by default."""
        return None
class Print_Config_Runner(Slate_Runner):
    """Runner that prints its config as given and after placeholder expansion.

    Afterwards the config is emptied so Slate finds nothing left unconsumed.
    """

    def run(self, run):
        cfg = self.config
        pprint(cfg)
        print('---')
        # Wrap the whole config so consume() can take it as one value; with
        # expand=True the placeholders inside it are expanded in place.
        expanded = self.slate.consume({'ptr': cfg}, 'ptr', expand=True)
        pprint(expanded)
        cfg.clear()
class PDB_Runner(Slate_Runner):
    """Runner that opens an interactive pdb session instead of doing any work.

    Useful for inspecting ``self.config`` and ``self.slate`` at the point where
    a real runner would start.
    """

    def run(self, run):
        pdb.set_trace()
if __name__ == '__main__':
    # This module is a library, not a script. Drive it from your own entry
    # script (run_slurm submits jobs that execute 'main.py'), presumably via
    # something like Slate(runners).from_args().
    raise Exception('You are using it wrong...')