Slate/slate/slate.py

294 lines
10 KiB
Python
Raw Normal View History

2023-07-06 18:06:20 +02:00
import wandb
import yaml
import os
2023-07-07 13:10:06 +02:00
import math
2023-07-06 18:06:20 +02:00
import random
import copy
import collections.abc
from functools import partial
2023-07-07 14:39:38 +02:00
from multiprocessing import Process
2023-07-09 16:12:38 +02:00
from threading import Thread
2023-07-06 18:06:20 +02:00
import pdb
d = pdb.set_trace
2023-07-07 16:40:30 +02:00
# When True, config keys a runner leaves unconsumed raise an exception after
# the run; when False they are only printed.
REQUIRE_CONFIG_CONSUMED = False

# Class used to start parallel agents (threads by default; the trailing
# comment records multiprocessing.Process as the alternative).
Parallelization_Primitive = Thread  # Process

# pyslurm is optional; only Slurm submission needs it.
try:
    import pyslurm
    slurm_avaible = True
except ImportError:
    slurm_avaible = False
    print('[!] Slurm not avaible.')

# TODO: Implement Slurm
# TODO: Implement Parallel
# TODO: Implement Testing
# TODO: Implement Ablative
# TODO: Implement PCA
class Slate():
    """Config-driven experiment launcher.

    Loads a named experiment document from a multi-document YAML file,
    merges in everything it imports, and runs the configured runner either
    locally (optionally split across several agents) or as a Slurm job
    array. Code that uses a config value is expected to *consume* (pop) it;
    keys still left after a run are reported.
    """

    def __init__(self, runners):
        """Create a launcher.

        Args:
            runners: mapping from runner name (the config's ``runner`` key)
                to a callable ``runner(slate, run, config)``. A built-in
                ``printConfig`` runner is added to this mapping in place.
        """
        self.runners = runners
        self.runners['printConfig'] = print_config_runner

    def load_config(self, filename, name):
        """Load experiment ``name`` from ``filename`` with all imports merged.

        A deep copy of the merged config (taken before ``vars`` is removed)
        is stored in ``self._config`` so consumed strings can reference it as
        ``{config[...]}``. The ``vars`` section is consumed and not returned.
        """
        config, stack = self._load_config(filename, name)
        print('[i] Merged Configs: ', stack)
        self._config = copy.deepcopy(config)
        self.consume(config, 'vars', {})
        return config

    def _load_config(self, filename, name, stack=None):
        """Recursively load experiment ``name`` from ``filename``.

        Args:
            filename: path of a (multi-document) YAML file.
            name: value of the ``name`` key of the document to load.
            stack: list of ``'file:name'`` entries visited so far; appended
                to in place. A fresh list is created when omitted.

        Returns:
            ``(config, stack)``: the document merged on top of everything
            it imports, and the visit list.

        Raises:
            Exception: if no document with that name exists in the file, or
                an import entry is malformed.
        """
        # A `stack=[]` default would be shared by every call and keep growing.
        if stack is None:
            stack = []
        stack.append(f'{filename}:{name}')
        with open(filename, 'r') as f:
            for doc in yaml.safe_load_all(f):
                # Skip empty documents (None) and non-mapping documents.
                if not isinstance(doc, dict) or 'name' not in doc or doc['name'] != name:
                    continue
                for imp in self._split_imports(doc.pop('import', None)):
                    nested_path, nested_name = self._resolve_import(filename, imp)
                    child, stack = self._load_config(nested_path, nested_name, stack=stack)
                    # Values of the importing document override the imported ones.
                    doc = self.deep_update(child, doc)
                return doc, stack
        raise Exception(f'Unable to find experiment <{name}> in <{filename}>')

    @staticmethod
    def _split_imports(raw):
        """Turn an ``import`` value into a list of stripped, non-empty entries.

        Accepts a comma-separated string (``"a.yaml, :exp, $"``), a YAML list
        of entries, or ``None`` (no imports).
        """
        if raw is None:
            return []
        entries = raw.split(',') if isinstance(raw, str) else [str(e) for e in raw]
        return [e.strip() for e in entries if e.strip()]

    @staticmethod
    def _resolve_import(filename, imp):
        """Map one import entry to ``(path, experiment_name)``.

        Supported forms (paths relative to the directory of ``filename``):
        ``file:exp``, ``file`` (experiment ``DEFAULT``), ``:exp`` (same file)
        and ``$`` (shorthand for ``:DEFAULT``).
        """
        if imp == "$":
            imp = ':DEFAULT'
        rel_path, *opt = imp.split(':')
        if len(opt) > 1:
            raise Exception('Malformed import statement. Must be <import file:exp>, <import :exp>, <import file> for file:DEFAULT or <import $> for :DEFAULT.')
        nested_name = opt[0] if opt else 'DEFAULT'
        if rel_path:
            nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path))
        else:
            nested_path = filename
        return nested_path, nested_name

    def deep_update(self, d, u):
        """Recursively merge ``u`` into ``d`` in place and return ``d``.

        Keys of ``u`` may be dotted paths (``'a.b.c'``); missing or
        non-mapping intermediate levels are replaced by new dicts. Mapping
        values are merged into the existing mapping at that path; all other
        values overwrite it.
        """
        for kstr, v in u.items():
            *parents, leaf = kstr.split('.')
            head = d
            for k in parents:
                child = head.get(k)
                if not isinstance(child, collections.abc.MutableMapping):
                    child = head[k] = {}
                head = child
            if isinstance(v, collections.abc.Mapping):
                # Merge into the value at the *leaf's* parent, not at the root.
                existing = head.get(leaf)
                if not isinstance(existing, collections.abc.MutableMapping):
                    existing = {}
                head[leaf] = self.deep_update(existing, v)
            else:
                head[leaf] = v
        return d

    def expand_vars(self, string, **kwargs):
        """Expand ``str.format`` placeholders in ``string``.

        The exact string ``'{rand}'`` becomes a random int. In any other
        string, ``{rand}`` is replaced by a random number and the remaining
        placeholders are filled from ``kwargs``. Non-strings are returned
        unchanged.
        """
        if not isinstance(string, str):
            return string
        if string == '{rand}':
            return int(random.random()*99999999)
        # A caller-supplied `rand` takes precedence instead of raising
        # "got multiple values for keyword argument 'rand'".
        return string.format(**{'rand': int(random.random()*99999999), **kwargs})

    def apply_nested(self, d, f):
        """Replace every leaf value of the nested dict/list ``d`` by ``f(leaf)`` in place."""
        for k, v in d.items():
            if isinstance(v, dict):
                self.apply_nested(v, f)
            elif isinstance(v, list):
                for i, e in enumerate(v):
                    # Box the element so the dict code path can handle it,
                    # then write the (possibly replaced) element back.
                    box = {'PTR': e}
                    self.apply_nested(box, f)
                    v[i] = box['PTR']
            else:
                d[k] = f(v)

    def deep_expand_vars(self, dict, **kwargs):
        """Apply :meth:`expand_vars` with ``kwargs`` to every leaf of ``dict`` in place."""
        self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))

    def consume(self, conf, key, default=None, **kwargs):
        """Remove ``key`` (a dotted path is allowed) from ``conf`` and return its value.

        Args:
            conf: config dict. The key is deleted from it, or from the nested
                dict the dotted path points into.
            key: key or dotted path such as ``'sweep.enable'``.
            default: value returned when the key is absent. ``None`` means
                the key is required, and a missing key raises ``KeyError``.
            **kwargs: extra names available to ``{...}`` placeholders.

        String values are expanded with :meth:`expand_vars` (with ``config``
        bound to ``self._config``) until no placeholders remain.
        """
        head, *rest = key.split('.')
        if rest:
            child = conf.get(head, {})
            return self.consume(child, '.'.join(rest), default=default, **kwargs)
        if default is not None:
            val = conf.get(head, default)
        else:
            val = conf[head]
        conf.pop(head, None)
        # expand_vars can return a non-string (an int for '{rand}'), so the
        # type is re-checked every round. Stop if expansion makes no progress
        # (e.g. a value that formats to itself) instead of looping forever.
        while isinstance(val, str) and '{' in val:
            expanded = self.expand_vars(val, config=self._config, **kwargs)
            if expanded == val:
                break
            val = expanded
        return val

    def _calc_num_jobs(self, schedulerC):
        """Return how many jobs are needed so that no job exceeds
        ``agents_per_job * reps_per_agent`` of the ``repetitions``."""
        reps = schedulerC.get('repetitions', 1)
        agents_per_job = schedulerC.get('agents_per_job', 1)
        reps_per_agent = schedulerC.get('reps_per_agent', 1)
        reps_per_job = reps_per_agent * agents_per_job
        return math.ceil(reps / reps_per_job)

    def _reps_for_job(self, schedulerC, job_id):
        """Return the repetition indices job ``job_id`` should run.

        Repetitions are dealt round-robin over :meth:`_calc_num_jobs` jobs.
        ``job_id=None`` (not part of a job array) returns all repetitions.
        """
        reps = schedulerC.get('repetitions', 1)
        if job_id is None:
            return list(range(0, reps))
        num_jobs = self._calc_num_jobs(schedulerC)
        # One independent list per job; `[[]] * n` would alias a single list.
        reps_for_job = [[] for _ in range(num_jobs)]
        for i in range(reps):
            reps_for_job[i % num_jobs].append(i)
        return reps_for_job[job_id]

    def run_local(self, filename, name, job_id):
        """Run experiment ``name`` from ``filename`` in this process.

        ``job_id`` selects this job's share of the repetitions (as a Slurm
        array task does). ``None`` runs all of them.
        """
        config = self.load_config(filename, name)
        schedulerC = copy.deepcopy(config.get('scheduler', {}))
        rep_ids = self._reps_for_job(schedulerC, job_id)
        self._fork_processes(config, rep_ids)

    def run_slurm(self, filename, name, python_script='main.py'):
        """Submit experiment ``name`` from ``filename`` as a Slurm job array.

        Each array task runs ``python3 <python_script> <filename> <name> -j
        $SLURM_ARRAY_TASK_ID``. Keys left in the ``slurm`` section after
        ``name``, ``sh_lines``, ``venv`` and ``num_parallel_jobs`` are consumed
        are passed to ``pyslurm.JobSubmitDescription``.

        Args:
            python_script: entry script each array task runs (default
                ``'main.py'``, the value previously hard-coded).
        """
        import shlex
        assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
        config = self.load_config(filename, name)
        slurmC = self.consume(config, 'slurm')
        schedC = self.consume(config, 'scheduler', {})
        s_name = self.consume(slurmC, 'name')
        sh_lines = ['#!/bin/bash']
        sh_lines += self.consume(slurmC, 'sh_lines', [])
        if venv := self.consume(slurmC, 'venv', False):
            sh_lines += [f'source activate {venv}']
        # Quote paths/names so spaces or shell metacharacters cannot break the script.
        sh_lines += [f'python3 {shlex.quote(python_script)} {shlex.quote(filename)} {shlex.quote(name)} -j $SLURM_ARRAY_TASK_ID']
        script = "\n".join(sh_lines)
        num_jobs = self._calc_num_jobs(schedC)
        last_job_idx = num_jobs - 1
        num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
        array = f'0-{last_job_idx}%{num_parallel_jobs}'
        job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
        job_id = job.submit()
        print(f'[i] Job submitted to slurm with id {job_id}')

    def _fork_processes(self, config, rep_ids):
        """Run ``rep_ids`` locally, split across up to ``agents_per_job`` agents.

        Each agent normally takes ``reps_per_agent`` consecutive repetitions.
        If ``agents_per_job`` caps the agent count, each agent's share grows
        so that no repetition is dropped. Agents are created with
        ``Parallelization_Primitive`` and all are joined before returning.
        """
        schedC = self.consume(config, 'scheduler', {})
        agents_per_job = self.consume(schedC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)

        node_reps = len(rep_ids)
        num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))
        if num_p < 1:
            return
        share = max(reps_per_agent, math.ceil(node_reps / num_p))
        chunks = [rep_ids[i:i + share] for i in range(0, node_reps, share)]

        if len(chunks) == 1:
            # Route through _run_process, like the parallel path does, so
            # sweeps also work when a single agent is enough.
            self._run_process(config, rep_ids=chunks[0], p_ind=0)
            return

        procs = []
        for p, proc_rep_ids in enumerate(chunks):
            proc = Parallelization_Primitive(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
            proc.start()
            procs.append(proc)
        for proc in procs:
            proc.join()

    def _run_process(self, orig_config, rep_ids, p_ind):
        """Run one agent: either a W&B sweep or plain repetitions.

        If ``sweep.enable`` is truthy, the rest of the ``sweep`` section is
        registered with ``wandb.sweep`` in the ``wandb.project`` project, and
        ``len(rep_ids)`` trials are run through ``wandb.agent``. Otherwise
        each id in ``rep_ids`` is run once via :meth:`_run_single`.
        """
        config = copy.deepcopy(orig_config)
        if self.consume(config, 'sweep.enable', False):
            sweepC = self.consume(config, 'sweep')
            project = config['wandb']['project']
            sweep_id = wandb.sweep(
                sweep=sweepC,
                project=project
            )
            wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
        else:
            self.consume(config, 'sweep', {})
            self._run_single(config, rep_ids, p_ind=p_ind)

    def _run_single(self, orig_config, rep_ids, p_ind):
        """Run the configured runner once per id in ``rep_ids``.

        ``runner`` and ``wandb`` are consumed from ``orig_config`` once. Each
        repetition then works on its own deep copy of the remainder, so one
        run's consumption cannot leak into the next.
        """
        print(f'[P{p_ind}] I will work on reps {rep_ids}')
        runnerName = self.consume(orig_config, 'runner')
        wandbC = self.consume(orig_config, 'wandb', {})
        project = self.consume(wandbC, 'project')
        runner = self.runners[runnerName]
        for r in rep_ids:
            config = copy.deepcopy(orig_config)
            with wandb.init(
                project=project,
                config=copy.deepcopy(config),
                **wandbC
            ) as run:
                runner(self, run, config)
                self._check_consumed(config)

    def _run_from_sweep(self, orig_config, p_ind):
        """Run one W&B sweep trial: the config is overlaid with ``wandb.config``.

        ``orig_config`` is shared by every trial of this agent, so it is
        deep-copied before anything is consumed from it.
        """
        config = copy.deepcopy(orig_config)
        runnerName = self.consume(config, 'runner')
        wandbC = self.consume(config, 'wandb', {})
        project = self.consume(wandbC, 'project')
        runner = self.runners[runnerName]
        with wandb.init(
            project=project,
            **wandbC
        ) as run:
            self.deep_update(config, wandb.config)
            runner(self, run, config)
            self._check_consumed(config)

    def _check_consumed(self, config):
        """Report keys a runner left in ``config``.

        Raises if the module-level ``REQUIRE_CONFIG_CONSUMED`` is set,
        otherwise prints the leftovers.
        """
        if config:
            msg = f'Config was not completely consumed: {config}'
            if REQUIRE_CONFIG_CONSUMED:
                raise Exception(msg)
            print(msg)

    def from_args(self):
        """Parse the command line and dispatch to :meth:`run_local` or :meth:`run_slurm`.

        Usage: ``CONFIG_FILE [EXPERIMENT] [-s] [-j JOB_ID]``. EXPERIMENT
        defaults to ``DEFAULT``; ``-w/--worker`` is not implemented.
        """
        import argparse
        parser = argparse.ArgumentParser()
        parser.add_argument("config_file", nargs='?', default=None)
        parser.add_argument("experiment", nargs='?', default='DEFAULT')
        parser.add_argument("-s", "--slurm", action="store_true")
        parser.add_argument("-w", "--worker", action="store_true")
        parser.add_argument("-j", "--job_id", default=None, type=int)
        args = parser.parse_args()
        if args.worker:
            raise NotImplementedError('Not yet implemented')
        if args.config_file is None:
            # Unlike `assert`, this is not stripped under `python -O`.
            parser.error('Need to supply config file.')
        if args.slurm:
            self.run_slurm(args.config_file, args.experiment)
        else:
            self.run_local(args.config_file, args.experiment, args.job_id)
2023-07-06 18:06:20 +02:00
2023-07-06 18:20:37 +02:00
def print_config_runner(slate, run, config):
    """Built-in ``printConfig`` runner: print the config, then consume all of it.

    ``slate`` and ``run`` are accepted to match the runner signature but are
    not used. ``config`` is emptied in place, so no keys are reported as
    unconsumed afterwards.
    """
    print(config)
    config.clear()
2023-07-06 18:06:20 +02:00
if __name__ == '__main__':
    # This module is meant to be imported (build a Slate with your runners
    # and call from_args() from your own script), not executed directly.
    raise Exception('You are using it wrong...')