Slate/slate/slate.py
2023-07-09 16:18:03 +02:00

294 lines
10 KiB
Python

import wandb
import yaml
import os
import math
import random
import copy
import collections.abc
from functools import partial
from multiprocessing import Process
from threading import Thread
import pdb
# Debug shorthand: call d() anywhere to drop into the debugger.
d = pdb.set_trace

# When True, a runner leaving unconsumed config keys raises instead of warning.
REQUIRE_CONFIG_CONSUMED = False

# Primitive used for parallel agents; swap in Process for true parallelism
# (threads share the GIL but are cheaper and sufficient for I/O-bound runs).
Parallelization_Primitive = Thread # Process

# Probe for slurm support; pyslurm is optional and only needed for run_slurm().
try:
    import pyslurm
except ImportError:
    slurm_avaible = False
    print('[!] Slurm not avaible.')
else:
    slurm_avaible = True

# TODO: Implement Slurm
# TODO: Implement Parallel
# TODO: Implement Testing
# TODO: Implement Ablative
# TODO: Implement PCA
class Slate():
    """Experiment launcher.

    Loads layered YAML experiment configs (documents selected by ``name``,
    merged along ``import`` chains), then dispatches repetitions either
    locally (optionally across several threads/processes), through a wandb
    sweep, or as a slurm job array.

    Runners are callables ``runner(slate, wandb_run, config)`` registered by
    name; they are expected to ``consume()`` every key of their config.
    """

    def __init__(self, runners):
        # Mapping of runner name -> callable(slate, run, config).
        self.runners = runners
        self.runners['printConfig'] = print_config_runner

    def load_config(self, filename, name):
        """Load experiment *name* from *filename*, resolving all imports."""
        config, stack = self._load_config(filename, name)
        print('[i] Merged Configs: ', stack)
        # Pristine snapshot so '{config.…}' template vars can reference the
        # full config even after parts of it have been consumed.
        self._config = copy.deepcopy(config)
        self.consume(config, 'vars', {})
        return config

    def _load_config(self, filename, name, stack=None):
        """Recursively load + merge the YAML document chain for *name*.

        Returns (config_dict, stack) where *stack* records every merged
        ``file:experiment`` pair for diagnostics.
        """
        # A fresh list per call: a mutable default ([]) would leak entries
        # between independent load_config() invocations.
        if stack is None:
            stack = []
        stack.append(f'{filename}:{name}')
        with open(filename, 'r') as f:
            docs = yaml.safe_load_all(f)
            for doc in docs:
                # Empty or non-mapping YAML documents cannot match.
                if not isinstance(doc, dict):
                    continue
                if doc.get('name') != name:
                    continue
                if 'import' in doc:
                    imports = doc['import'].split(',')
                    del doc['import']
                    for imp in imports:
                        imp = imp.strip()
                        if imp == "$":
                            imp = ':DEFAULT'
                        rel_path, *opt = imp.split(':')
                        if len(opt) == 0:
                            nested_name = 'DEFAULT'
                        elif len(opt) == 1:
                            nested_name = opt[0]
                        else:
                            raise Exception('Malformed import statement. Must be <import file:exp>, <import :exp>, <import file> for file:DEFAULT or <import $> for :DEFAULT.')
                        # An empty path means "this same file".
                        nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if len(rel_path) else filename
                        child, stack = self._load_config(nested_path, nested_name, stack=stack)
                        # The importing doc wins over the imported one.
                        doc = self.deep_update(child, doc)
                return doc, stack
        raise Exception(f'Unable to find experiment <{name}> in <{filename}>')

    def deep_update(self, d, u):
        """Recursively merge *u* into *d*; keys in *u* may be dotted paths."""
        for kstr, v in u.items():
            ks = kstr.split('.')
            head = d
            for k in ks:
                last_head = head
                if k not in head:
                    head[k] = {}
                head = head[k]
            leaf = ks[-1]
            if isinstance(v, collections.abc.Mapping):
                # Merge into the mapping that actually sits at the dotted
                # path (not a same-named key on the top-level dict).
                last_head[leaf] = self.deep_update(last_head.get(leaf, {}), v)
            else:
                last_head[leaf] = v
        return d

    def expand_vars(self, string, **kwargs):
        """Expand '{…}' template placeholders in *string* (non-str passthrough)."""
        if isinstance(string, str):
            if string == '{rand}':
                return int(random.random() * 99999999)
            # NOTE(review): the original referenced an undefined name `srand`,
            # so every format call raised NameError. We supply an independent
            # random value for '{srand}' placeholders — confirm intended
            # semantics against configs that use it.
            return string.format(**kwargs,
                                 rand=int(random.random() * 99999999),
                                 srand=int(random.random() * 99999999))
        return string

    def apply_nested(self, d, f):
        """Apply *f* to every leaf value of the nested dict/list structure *d*."""
        for k, v in d.items():
            if isinstance(v, dict):
                self.apply_nested(v, f)
            elif isinstance(v, list):
                for i, e in enumerate(v):
                    if isinstance(e, (dict, list)):
                        # Containers are mutated in place via recursion.
                        self.apply_nested({'PTR': e}, f)
                    else:
                        # Write scalars back into the list (the original wrote
                        # into a throwaway temp dict and lost the result).
                        v[i] = f(e)
            else:
                d[k] = f(v)

    def deep_expand_vars(self, data, **kwargs):
        """Expand template vars in every string leaf of *data*, in place."""
        self.apply_nested(data, lambda x: self.expand_vars(x, **kwargs))

    def consume(self, conf, key, default=None, **kwargs):
        """Pop *key* (dotted paths allowed) out of *conf* and return it.

        String values are template-expanded against the full config.
        Raises KeyError when the key is missing and no default was given.
        """
        keys_arr = key.split('.')
        if len(keys_arr) == 1:
            k = keys_arr[0]
            if default is not None:
                val = conf.get(k, default)
            else:
                val = conf[k]
            if k in conf:
                del conf[k]
            if isinstance(val, str):
                # Re-expand until no '{' remains: expansions may themselves
                # contain further template vars.
                while val.find('{') != -1:
                    val = self.expand_vars(val, config=self._config, **kwargs)
            return val
        child = conf.get(keys_arr[0], {})
        child_keys = '.'.join(keys_arr[1:])
        return self.consume(child, child_keys, default=default, **kwargs)

    def _calc_num_jobs(self, schedulerC):
        """Number of (slurm array) jobs needed to cover all repetitions."""
        reps = schedulerC.get('repetitions', 1)
        agents_per_job = schedulerC.get('agents_per_job', 1)
        reps_per_agent = schedulerC.get('reps_per_agent', 1)
        reps_per_job = reps_per_agent * agents_per_job
        return math.ceil(reps / reps_per_job)

    def _reps_for_job(self, schedulerC, job_id):
        """Return the repetition indices assigned to *job_id* (round-robin)."""
        reps = schedulerC.get('repetitions', 1)
        if job_id is None:
            # No job id -> unsharded local execution of everything.
            return list(range(0, reps))
        num_jobs = self._calc_num_jobs(schedulerC)
        # One *distinct* list per job; '[[]] * n' would alias a single list
        # and hand every job the full set of repetitions.
        reps_for_job = [[] for _ in range(num_jobs)]
        for i in range(reps):
            reps_for_job[i % num_jobs].append(i)
        return reps_for_job[job_id]

    def run_local(self, filename, name, job_id):
        """Run this job's share of repetitions in the current process."""
        config = self.load_config(filename, name)
        schedulerC = copy.deepcopy(config.get('scheduler', {}))
        rep_ids = self._reps_for_job(schedulerC, job_id)
        self._fork_processes(config, rep_ids)

    def run_slurm(self, filename, name):
        """Submit this experiment as a slurm job array via pyslurm."""
        assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
        config = self.load_config(filename, name)
        slurmC = self.consume(config, 'slurm')
        schedC = self.consume(config, 'scheduler')
        s_name = self.consume(slurmC, 'name')
        python_script = 'main.py'
        sh_lines = ['#!/bin/bash']
        sh_lines += self.consume(slurmC, 'sh_lines', [])
        if venv := self.consume(slurmC, 'venv', False):
            sh_lines += [f'source activate {venv}']
        # Each array task re-invokes this launcher with its own job index.
        sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
        script = "\n".join(sh_lines)
        num_jobs = self._calc_num_jobs(schedC)
        last_job_idx = num_jobs - 1
        num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
        array = f'0-{last_job_idx}%{num_parallel_jobs}'
        # Remaining slurmC keys are passed straight through to slurm.
        job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
        job_id = job.submit()
        print(f'[i] Job submitted to slurm with id {job_id}')

    def _fork_processes(self, config, rep_ids):
        """Distribute *rep_ids* over up to ``agents_per_job`` parallel agents."""
        schedC = self.consume(config, 'scheduler')
        agents_per_job = self.consume(schedC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
        node_reps = len(rep_ids)
        num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))
        if num_p == 1:
            # Same code path as the forked agents so sweep configs are
            # honored (and consumed) in the single-agent case as well.
            self._run_process(config, rep_ids=rep_ids, p_ind=0)
            return
        procs = []
        reps_done = 0
        for p in range(num_p):
            num_reps = min(node_reps - reps_done, reps_per_agent)
            proc_rep_ids = [rep_ids[i] for i in range(reps_done, reps_done + num_reps)]
            proc = Parallelization_Primitive(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
            proc.start()
            procs.append(proc)
            reps_done += num_reps
        for proc in procs:
            proc.join()

    def _run_process(self, orig_config, rep_ids, p_ind):
        """Agent entry point: run reps directly or attach to a wandb sweep."""
        config = copy.deepcopy(orig_config)
        if self.consume(config, 'sweep.enable', False):
            sweepC = self.consume(config, 'sweep')
            project = config['wandb']['project']
            sweep_id = wandb.sweep(
                sweep=sweepC,
                project=project
            )
            wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
        else:
            # Discard the (disabled) sweep block so it counts as consumed.
            self.consume(config, 'sweep', {})
            self._run_single(config, rep_ids, p_ind=p_ind)

    def _ensure_config_consumed(self, config):
        """Warn — or raise, if REQUIRE_CONFIG_CONSUMED — about leftover keys."""
        if config != {}:
            msg = ('Config was not completely consumed: ', config)
            if REQUIRE_CONFIG_CONSUMED:
                raise Exception(msg)
            else:
                print(msg)

    def _run_single(self, orig_config, rep_ids, p_ind):
        """Run the configured runner once per repetition, each in its own wandb run."""
        print(f'[P{p_ind}] I will work on reps {rep_ids}')
        runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {})
        project = self.consume(wandbC, 'project')
        runner = self.runners[runnerName]
        for r in rep_ids:
            # Fresh copy per repetition: runners consume (mutate) their config.
            config = copy.deepcopy(orig_config)
            with wandb.init(
                project=project,
                config=copy.deepcopy(config),
                **wandbC
            ) as run:
                runner(self, run, config)
            self._ensure_config_consumed(config)

    def _run_from_sweep(self, orig_config, p_ind):
        """wandb-sweep callback: merge sweep params over the config and run."""
        runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {})
        project = self.consume(wandbC, 'project')
        runner = self.runners[runnerName]
        with wandb.init(
            project=project,
            **wandbC
        ) as run:
            config = copy.deepcopy(orig_config)
            # Sweep-chosen hyperparameters override the static config.
            self.deep_update(config, wandb.config)
            runner(self, run, config)
        self._ensure_config_consumed(config)

    def from_args(self):
        """CLI entry point: parse args and dispatch to slurm or local execution."""
        import argparse
        parser = argparse.ArgumentParser()
        parser.add_argument("config_file", nargs='?', default=None)
        parser.add_argument("experiment", nargs='?', default='DEFAULT')
        parser.add_argument("-s", "--slurm", action="store_true")
        parser.add_argument("-w", "--worker", action="store_true")
        parser.add_argument("-j", "--job_id", default=None, type=int)
        args = parser.parse_args()
        if args.worker:
            raise Exception('Not yet implemented')
        assert args.config_file is not None, 'Need to supply config file.'
        if args.slurm:
            self.run_slurm(args.config_file, args.experiment)
        else:
            self.run_local(args.config_file, args.experiment, args.job_id)
def print_config_runner(slate, run, config):
    """Debug runner: dump the merged config, then mark all of it consumed."""
    print(config)
    # Emptying the dict signals that every key was handled.
    config.clear()
if __name__ == '__main__':
    # This module is a library: import Slate and call from_args() from your
    # own main script instead of executing this file directly.
    raise Exception('You are using it wrong...')