Slate/slate/slate.py
2023-07-12 12:23:18 +02:00

340 lines
12 KiB
Python

import wandb
import yaml
import os
import math
import random
import copy
import collections.abc
from functools import partial
from multiprocessing import Process
from threading import Thread
import git
import datetime
import pdb
# Short alias for dropping into the debugger; invoked by `pdb_runner`.
d = pdb.set_trace
# When True, a run whose config still has unconsumed keys after the runner
# returns raises an exception; when False the leftover keys are only printed.
REQUIRE_CONFIG_CONSUMED = False
# Passed as `start_method` to `wandb.Settings` for every `wandb.init` / sweep.
WANDB_START_METHOD = 'fork'
# Passed as `reinit` to every `wandb.init` call.
REINIT = True
# Class used to spawn parallel workers in `Slate._fork_processes`.
Parallelization_Primitive = Thread # Process
# pyslurm is an optional dependency: record whether it could be imported so
# that `Slate.run_slurm` can refuse to run without it.
try:
    import pyslurm
    slurm_avaible = True
except ImportError:
    slurm_avaible = False
    print('[!] Slurm not avaible.')
# TODO: Implement Slurm
# TODO: Implement Parallel
# TODO: Implement Testing
# TODO: Implement Ablative
# TODO: Implement PCA
class Slate():
    """Load layered YAML experiment configs and run them under wandb.

    A config file holds several YAML documents, each identified by a ``name``
    key. A document may ``import`` other documents (from the same or another
    file); imported documents are merged first and then overridden by the
    importing one. Runners are plain callables ``runner(slate, run, config)``
    looked up by the config's ``runner`` key; they are expected to *consume*
    (read and remove) every config key they use.
    """

    def __init__(self, runners):
        """Register the user's runners plus the built-in helper runners.

        Args:
            runners: dict mapping runner names to callables. The dict is
                stored as-is and extended in place with ``printConfig`` and
                ``pdb``.
        """
        self.runners = runners
        self.runners['printConfig'] = print_config_runner
        self.runners['pdb'] = pdb_runner
        self._version = False

    def load_config(self, filename, name):
        """Load experiment ``name`` from ``filename`` with all imports merged.

        A pristine deep copy of the merged config is kept in ``self._config``
        for later ``{config[...]}`` expansion; the ``vars`` key is then
        removed from the returned config.

        Returns:
            The merged config dict (without ``vars``).
        """
        config, stack = self._load_config(filename, name)
        print('[i] Merged Configs: ', stack)
        self._config = copy.deepcopy(config)
        self.consume(config, 'vars', {})
        return config

    def _load_config(self, filename, name, stack=None):
        """Recursively load one named YAML document and resolve its imports.

        Args:
            filename: path of the YAML file to read.
            name: value of the ``name`` key identifying the document.
            stack: list collecting ``file:name`` entries of every document
                visited so far; a fresh list is created when omitted.

        Returns:
            Tuple ``(merged_doc, stack)``.

        Raises:
            Exception: on a malformed import entry or when no document with
                the requested name exists in the file.
        """
        # A fresh list per top-level call; a mutable default would leak the
        # trace of earlier loads into later ones.
        if stack is None:
            stack = []
        stack.append(f'{filename}:{name}')
        with open(filename, 'r') as f:
            for doc in yaml.safe_load_all(f):
                # Empty YAML documents load as None; skip anything that cannot
                # carry a name.
                if not isinstance(doc, collections.abc.Mapping):
                    continue
                if 'name' not in doc or doc['name'] != name:
                    continue
                if 'import' in doc:
                    imports = doc['import'].split(',')
                    del doc['import']
                    for imp in imports:
                        imp = imp.strip()
                        if not imp:
                            # Tolerate stray commas in the import list.
                            continue
                        if imp == "$":
                            imp = ':DEFAULT'
                        rel_path, *opt = imp.split(':')
                        if len(opt) == 0:
                            nested_name = 'DEFAULT'
                        elif len(opt) == 1:
                            nested_name = opt[0]
                        else:
                            raise Exception('Malformed import statement. Must be <import file:exp>, <import :exp>, <import file> for file:DEFAULT or <import $> for :DEFAULT.')
                        # An empty path means "this same file"; otherwise the
                        # path is relative to the importing file.
                        if len(rel_path):
                            nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path))
                        else:
                            nested_path = filename
                        child, stack = self._load_config(nested_path, nested_name, stack=stack)
                        # The importing document overrides what it imports.
                        doc = self.deep_update(child, doc)
                return doc, stack
        raise Exception(f'Unable to find experiment <{name}> in <{filename}>')

    def deep_update(self, d, u):
        """Recursively merge mapping ``u`` into dict ``d`` in place.

        String keys of ``u`` may be dotted paths (``'a.b.c'``) addressing
        nested dicts in ``d``; missing (or non-dict) intermediate nodes are
        replaced by fresh dicts. Mapping values are merged recursively into
        the node at that path, all other values overwrite it.

        Returns:
            ``d`` itself.
        """
        for kstr, v in u.items():
            # Only strings can be dotted paths; other key types are used as-is.
            path = kstr.split('.') if isinstance(kstr, str) else [kstr]
            *parents, leaf = path
            head = d
            for k in parents:
                if not isinstance(head.get(k), collections.abc.MutableMapping):
                    head[k] = {}
                head = head[k]
            if isinstance(v, collections.abc.Mapping):
                # Merge into the node at the *walked* path, not into a
                # same-named key at the top level of ``d``.
                target = head.get(leaf)
                if not isinstance(target, collections.abc.MutableMapping):
                    target = {}
                head[leaf] = self.deep_update(target, v)
            else:
                head[leaf] = v
        return d

    def expand_vars(self, string, **kwargs):
        """Expand ``str.format`` placeholders in ``string``.

        A random integer is available as ``{rand}``; the exact string
        ``'{rand}'`` yields that integer itself rather than its text form.
        Non-string inputs are returned unchanged.
        """
        if isinstance(string, str):
            rand = int(random.random()*99999999)
            if string == '{rand}':
                return rand
            return string.format(**kwargs, rand=rand)
        return string

    def apply_nested(self, d, f):
        """Replace every leaf of the nested dict/list structure ``d`` by ``f(leaf)``.

        Works in place. Dict values are recursed into; list elements are
        processed one by one (including nested lists and dicts).
        """
        for k, v in d.items():
            if isinstance(v, dict):
                self.apply_nested(v, f)
            elif isinstance(v, list):
                for i, e in enumerate(v):
                    # Wrap the element in a one-key dict so the same routine
                    # handles leaves, dicts and nested lists alike.
                    ptr = {'PTR': e}
                    self.apply_nested(ptr, f)
                    v[i] = ptr['PTR']
            else:
                d[k] = f(v)

    def deep_expand_vars(self, dict, **kwargs):
        """Apply ``expand_vars`` (once) to every leaf of the nested ``dict``, in place."""
        self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))

    def consume(self, conf, key, default=None, expand=False, **kwargs):
        """Read a value from ``conf`` and remove it.

        Args:
            conf: config mapping to read from (mutated: the key is deleted).
            key: key name; a dotted path (``'a.b'``) descends into nested
                mappings and removes only the innermost key.
            default: value returned when the key is missing. ``None`` means
                the key is required.
            expand: if true, expand placeholders once in every leaf of the
                value; otherwise only a string value is expanded, repeatedly,
                until no ``{`` remains.
            **kwargs: extra names made available to placeholder expansion
                (``config`` and ``rand`` are always provided).

        Returns:
            The (possibly expanded) value, or ``default``.

        Raises:
            KeyError: if a required key is missing from a mapping.
        """
        head_key, sep, rest = key.partition('.')
        if sep:
            child = conf.get(head_key, {})
            return self.consume(child, rest, default=default, expand=expand, **kwargs)

        has_default = default is not None
        if has_default and not isinstance(conf, collections.abc.Mapping):
            # A parent on the path is not a mapping: nothing to look up.
            return default
        val = conf.get(key, default) if has_default else conf[key]
        if key in conf:
            del conf[key]

        if expand:
            if hasattr(val, 'items'):
                self.deep_expand_vars(val, config=self._config, **kwargs)
            else:
                # Scalars and lists are wrapped so the nested walker can
                # handle them too.
                holder = {'PTR': val}
                self.deep_expand_vars(holder, config=self._config, **kwargs)
                val = holder['PTR']
        else:
            # '{rand}' expands to an int, so the type must be re-checked on
            # every pass before looking for further placeholders.
            while isinstance(val, str) and '{' in val:
                val = self.expand_vars(val, config=self._config, **kwargs)
        return val

    def get_version(self):
        """Return the git commit hash of the enclosing repository (cached)."""
        if not self._version:
            repo = git.Repo(search_parent_directories=True)
            self._version = repo.head.object.hexsha
        return self._version

    def _calc_num_jobs(self, schedC):
        """Return how many jobs are needed to cover all repetitions.

        Reads ``repetitions``, ``agents_per_job`` and ``reps_per_agent``
        (each defaulting to 1) from a copy of ``schedC``.
        """
        schedulerC = copy.deepcopy(schedC)
        reps = self.consume(schedulerC, 'repetitions', 1)
        agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
        reps_per_job = reps_per_agent * agents_per_job
        return math.ceil(reps / reps_per_job)

    def _reps_for_job(self, schedC, job_id):
        """Return the repetition indices assigned to job ``job_id``.

        Repetitions are dealt round-robin over the jobs. ``job_id`` selects
        bucket ``job_id - 1``, so id 0 wraps around to the last bucket.
        With ``job_id`` of ``None`` all repetitions are returned.
        """
        schedulerC = copy.deepcopy(schedC)
        num_jobs = self._calc_num_jobs(schedulerC)
        reps = self.consume(schedulerC, 'repetitions', 1)
        if job_id is None:
            return list(range(0, reps))
        # One independent list per job; `[[]] * n` would alias a single list
        # and hand every repetition to every job.
        reps_for_job = [[] for _ in range(num_jobs)]
        for i in range(reps):
            reps_for_job[i % num_jobs].append(i)
        return reps_for_job[job_id-1]

    def run_local(self, filename, name, job_id):
        """Load an experiment and run this job's repetitions on this machine."""
        config = self.load_config(filename, name)
        schedulerC = copy.deepcopy(config.get('scheduler', {}))
        rep_ids = self._reps_for_job(schedulerC, job_id)
        self._fork_processes(config, rep_ids)

    def run_slurm(self, filename, name):
        """Submit the experiment as a slurm job array via pyslurm.

        Builds a bash script that re-invokes ``main.py`` with the array task
        id as job id, submits it, and appends a line to ``job_hist.log``.
        Keys left in the ``slurm`` config section are forwarded to
        ``pyslurm.JobSubmitDescription``.
        """
        assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
        config = self.load_config(filename, name)
        slurmC = self.consume(config, 'slurm', expand=True)
        # The scheduler section is optional, as in `run_local`.
        schedC = self.consume(config, 'scheduler', {})
        s_name = self.consume(slurmC, 'name')

        python_script = 'main.py'
        sh_lines = ['#!/bin/bash']
        sh_lines += self.consume(slurmC, 'sh_lines', [])
        if venv := self.consume(slurmC, 'venv', False):
            sh_lines += [f'source activate {venv}']
        sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
        script = "\n".join(sh_lines)

        num_jobs = self._calc_num_jobs(schedC)
        last_job_idx = num_jobs - 1
        num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
        array = f'0-{last_job_idx}%{num_parallel_jobs}'
        job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
        job_id = job.submit()
        print(f'[>] Job submitted to slurm with id {job_id}')
        with open('job_hist.log', 'a') as f:
            # Log the same index range that was actually submitted.
            f.write(f'{filename}:{name} submitted to slurm with ids {job_id}_0-{last_job_idx} on [git:{self.get_version()}] at {datetime.datetime.now()}\n')

    def _fork_processes(self, config, rep_ids):
        """Distribute ``rep_ids`` over parallel workers and wait for them.

        At most ``agents_per_job`` workers are started, each handling up to
        ``reps_per_agent`` repetitions. A single worker runs in the calling
        thread instead of being spawned.
        """
        # Optional section, consistent with `run_local`.
        schedC = self.consume(config, 'scheduler', {})
        agents_per_job = self.consume(schedC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)

        node_reps = len(rep_ids)
        num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))

        if num_p == 1:
            print('[i] Running within main thread')
            # Go through `_run_process` so sweep settings are honoured (and
            # the `sweep` key consumed) exactly as for spawned workers.
            self._run_process(config, rep_ids=rep_ids, p_ind=0)
            return

        # NOTE(review): repetitions beyond agents_per_job * reps_per_agent are
        # not assigned to any worker here - confirm this is intended.
        procs = []
        reps_done = 0
        for p in range(num_p):
            print(f'[i] Spawning seperate thread/process ({p+1}/{num_p})')
            num_reps = min(node_reps - reps_done, reps_per_agent)
            proc_rep_ids = rep_ids[reps_done:reps_done + num_reps]
            proc = Parallelization_Primitive(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
            proc.start()
            procs.append(proc)
            reps_done += num_reps

        for proc in procs:
            proc.join()
        print(f'[i] All threads/processes have terminated')

    def _run_process(self, orig_config, rep_ids, p_ind):
        """Worker entry point: run a wandb sweep agent or plain repetitions.

        Operates on a deep copy of ``orig_config``. If ``sweep.enable`` is
        set, a sweep is created and an agent executes ``len(rep_ids)`` runs;
        otherwise the ``sweep`` section is dropped and the repetitions are
        run directly.
        """
        config = copy.deepcopy(orig_config)
        if self.consume(config, 'sweep.enable', False):
            sweepC = self.consume(config, 'sweep')
            project = self.consume(copy.deepcopy(config['wandb']), 'project')
            sweep_id = wandb.sweep(
                sweep=sweepC,
                project=project,
                settings=wandb.Settings(start_method=WANDB_START_METHOD)
            )
            wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
        else:
            self.consume(config, 'sweep', {})
            self._run_single(config, rep_ids, p_ind=p_ind)

    def _check_consumed(self, config):
        """Report (or raise on) config keys the runner left unconsumed."""
        if config != {}:
            msg = ('Config was not completely consumed: ', config)
            if REQUIRE_CONFIG_CONSUMED:
                raise Exception(msg)
            else:
                print(msg)

    def _run_single(self, orig_config, rep_ids, p_ind):
        """Run the configured runner once per repetition, each in its own wandb run.

        ``runner`` and ``wandb`` are consumed from ``orig_config``; every
        repetition then receives a fresh deep copy of the remaining config.
        """
        print(f'[P{p_ind}] I will work on reps {rep_ids}')
        runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
        project = self.consume(wandbC, 'project')
        runner = self.runners[runnerName]

        for r in rep_ids:
            # Each repetition starts from the untouched remainder; the runner
            # consumes its own copy.
            config = copy.deepcopy(orig_config)
            with wandb.init(
                project=project,
                config=copy.deepcopy(config),
                reinit=REINIT,
                settings=wandb.Settings(start_method=WANDB_START_METHOD),
                **wandbC
            ) as run:
                runner(self, run, config)
            self._check_consumed(config)

    def _run_from_sweep(self, orig_config, p_ind):
        """Run one sweep trial: base config overridden by wandb's sweep values.

        Called repeatedly by ``wandb.agent`` with the same ``orig_config``
        object, which therefore must not be modified.
        """
        # Work on a private copy: consuming from the shared dict would make
        # the next trial fail on the missing 'runner' key.
        config = copy.deepcopy(orig_config)
        runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {}, expand=True)
        project = self.consume(wandbC, 'project')
        runner = self.runners[runnerName]

        with wandb.init(
            project=project,
            reinit=REINIT,
            settings=wandb.Settings(start_method=WANDB_START_METHOD),
            **wandbC
        ) as run:
            # Sweep-chosen parameters override the base config.
            self.deep_update(config, wandb.config)
            runner(self, run, config)
        self._check_consumed(config)

    def from_args(self):
        """Parse command-line arguments and dispatch to slurm or local execution."""
        import argparse

        parser = argparse.ArgumentParser()
        parser.add_argument("config_file", nargs='?', default=None)
        parser.add_argument("experiment", nargs='?', default='DEFAULT')
        parser.add_argument("-s", "--slurm", action="store_true")
        parser.add_argument("-w", "--worker", action="store_true")
        parser.add_argument("-j", "--job_id", default=None, type=int)
        args = parser.parse_args()

        print(f'[i] I have job_id {args.job_id}')
        print(f'[i] Running on version [git:{self.get_version()}]')

        if args.worker:
            raise Exception('Not yet implemented')

        assert args.config_file is not None, 'Need to supply config file.'
        if args.slurm:
            self.run_slurm(args.config_file, args.experiment)
        else:
            self.run_local(args.config_file, args.experiment, args.job_id)
def print_config_runner(slate, run, config):
    """Built-in runner: print the raw config, then its expanded form.

    The config is emptied afterwards so it counts as fully consumed.
    ``run`` is accepted for the runner interface but not used.
    """
    from pprint import pprint

    pprint(config)
    print('---')
    # Wrap the config so the whole dict can be consumed (and expanded) as
    # a single value.
    wrapper = {'ptr': config}
    expanded = slate.consume(wrapper, 'ptr', expand=True)
    pprint(expanded)
    config.clear()
def pdb_runner(slate, run, config):
    """Built-in runner: drop into the debugger via the module-level alias ``d``."""
    d()
if __name__ == '__main__':
    # Running this module directly is refused with an error.
    raise Exception('You are using it wrong...')