# Slate/slate/slate.py
import wandb
import yaml
import os
import math
import time
import random
import copy
import re
import itertools
from collections.abc import *  # NOTE: star import supplies Mapping, MutableMapping, MutableSequence used below
from functools import partial
from multiprocessing import Process
from threading import Thread
import git
import datetime
from pprint import pprint

import pdb
d = pdb.set_trace  # shorthand breakpoint: call d() anywhere to drop into the debugger
# When True, any config keys left unconsumed after a run raise instead of just printing a warning.
REQUIRE_CONFIG_CONSUMED = False

DEFAULT_START_METHOD = 'fork'
DEFAULT_REINIT = True  # default for wandb.init(reinit=...) when the config does not set it

# Backend used to spawn agents in _fork_processes: Process (default) or Thread.
Parallelization_Primitive = Process # Thread
# pyslurm is optional: without it only local execution is available.
try:
    import pyslurm
except ImportError:
    # NOTE(review): 'avaible' is a typo, but the name is module-level and may be
    # referenced elsewhere, so it is kept as-is.
    slurm_avaible = False
    print('[!] Slurm not available.')
else:
    slurm_avaible = True
class Slate():
    """Experiment launcher.

    Loads layered YAML experiment configs (with import chains), expands grid
    and ablative variants, and runs the resulting configurations either
    locally (in processes/threads) or as a SLURM job array, logging every
    repetition to wandb via a user-supplied Runner class.
    """

    def __init__(self, runners):
        # Built-in utility runners; user-supplied runners extend/override them.
        self.runners = {
            'void': Void_Runner,
            'printConfig': Print_Config_Runner,
            'pdb': PDB_Runner,
        }
        self.runners.update(runners)
        self._version = False  # lazily-filled git sha cache (see get_version)
        self.job_id = os.environ.get('SLURM_JOB_ID', False)  # False when not running under SLURM
        self.task_id = None   # SLURM array task id (set in run_local / from CLI)
        self.run_id = -1      # repetition index of the rep currently being run
        self._tmp_path = os.path.expandvars('$TMP')  # substituted for '{tmp}' placeholders
        self.verify = False   # when True, ask for confirmation before SLURM submission

    def load_config(self, filename, name):
        """Load experiment <name> from <filename>, resolving its import chain.

        Keeps a deep copy of the merged config in self._config (used as the
        'config' field during placeholder expansion) and strips the 'vars'
        section, which exists only to be referenced via expansion.
        """
        emptyStack = []
        config, stack = self._load_config(filename, name, stack=emptyStack)
        print('[i] Merged Configs: ', stack)
        self._config = copy.deepcopy(config)
        self.consume(config, 'vars', {})
        return config

    def _load_config(self, filename, name, stack):
        """Recursively load the YAML document whose 'name' matches, merging imports.

        Import syntax (comma separated entries): '<file>:<exp>', ':<exp>'
        (same file), '<file>' (that file's DEFAULT) or '$' (this file's
        DEFAULT). Imported children are loaded first and the current doc is
        deep-merged on top of them.
        """
        stack.append(f'(unknown):{name}')
        with open(filename, 'r') as f:
            docs = yaml.safe_load_all(f)
            for doc in docs:
                if 'name' in doc:
                    if doc['name'] == name:
                        if 'import' in doc:
                            imports = doc['import'].split(',')
                            del doc['import']
                            for imp in imports:
                                if imp[0] == ' ':  # tolerate a space after the comma
                                    imp = imp[1:]
                                if imp == "$":
                                    imp = ':DEFAULT'
                                rel_path, *opt = imp.split(':')
                                if len(opt) == 0:
                                    nested_name = 'DEFAULT'
                                elif len(opt) == 1:
                                    nested_name = opt[0]
                                else:
                                    raise Exception('Malformed import statement. Must be <import file:exp>, <import :exp>, <import file> for file:DEFAULT or <import $> for :DEFAULT.')
                                # An empty rel_path means "import from the current file".
                                nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if len(rel_path) else filename
                                child, stack = self._load_config(nested_path, nested_name, stack=stack)
                                doc = self.deep_update(child, doc)
                        return doc, stack
        raise Exception(f'Unable to find experiment <{name}> in <(unknown)>')

    def deep_update(self, d, u, traverse_dot_notation=True):
        """Deep-merge mapping u into d (in place) and return d.

        Keys containing '.' are treated as nested paths unless
        traverse_dot_notation is False; traversal is switched off below any
        'parameters' key so that dotted keys there are kept literal.
        """
        for kstr, v in u.items():
            if traverse_dot_notation:
                ks = kstr.split('.')
            else:
                ks = [kstr]
            head = d
            for k in ks:
                if k in ['parameters']:
                    traverse_dot_notation = False
                last_head = head
                if k not in head:
                    head[k] = {}
                head = head[k]
            if isinstance(v, Mapping):
                # NOTE(review): 'k' here is the last path element from the loop
                # above; 'd.get(k, {})' looks like it should be
                # 'last_head.get(ks[-1], {})' for multi-level dotted keys —
                # confirm against a dotted-path merge of a mapping value.
                last_head[ks[-1]] = self.deep_update(d.get(k, {}), v, traverse_dot_notation=traverse_dot_notation)
            else:
                last_head[ks[-1]] = v
        return d

    def expand_vars(self, string, delta_desc='BASE', **kwargs):
        """Expand '{...}' placeholders in a string via str.format.

        Available fields: delta_desc, rand (fresh random int per call), tmp,
        job_id ('LOCAL' outside SLURM), task_id, run_id, plus any kwargs
        (typically config=self._config). A value that is exactly '{rand}'
        returns an int rather than a string. Non-strings pass through.
        """
        if isinstance(string, str):
            rand = int(random.random() * 99999999)
            if string == '{rand}':
                return rand
            return string.format(delta_desc=delta_desc, **kwargs, rand=rand, tmp=self._tmp_path, job_id=(self.job_id or 'LOCAL'), task_id=(self.task_id or 0), run_id=self.run_id)
        return string

    def apply_nested(self, d, f):
        """Apply f to every leaf value of a nested dict/list structure, in place."""
        for k, v in d.items():
            if isinstance(v, dict):
                self.apply_nested(v, f)
            elif isinstance(v, list):
                for i, e in enumerate(v):
                    # Wrap each element in a one-key dict so nested containers
                    # inside lists are recursed into as well.
                    ptr = {'PTR': d[k][i]}
                    self.apply_nested(ptr, f)
                    d[k][i] = ptr['PTR']
            else:
                d[k] = f(v)

    def deep_expand_vars(self, dict, **kwargs):
        """Expand '{...}' placeholders in all string leaves of dict, in place."""
        self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))

    def consume(self, conf, key, default=None, expand=False, **kwargs):
        """Pop a dot-notation key out of conf and return its value.

        Removing consumed keys is what lets the leftover-config check detect
        unused options. With key == '' the whole conf is expanded (when
        expand=True) and returned instead. String values are format-expanded
        repeatedly until no '{' remains. Raises KeyError (via conf[k]) when
        the key is missing and no default is given.
        """
        if key == '':
            if expand:
                self.deep_expand_vars(conf, config=self._config, **kwargs)
            elif type(conf) == str:
                while conf.find('{') != -1:
                    conf = self.expand_vars(conf, config=self._config, **kwargs)
            return conf

        keys_arr = key.split('.')
        if len(keys_arr) == 1:
            k = keys_arr[0]
            if default is not None:
                if isinstance(conf, Mapping):
                    val = conf.get(k, default)
                else:
                    # conf is not a mapping (e.g. a missing parent path gave
                    # {}); the inner re-check is always true on this branch.
                    if default is not None:
                        return default
                    raise Exception('')
            else:
                val = conf[k]
            if k in conf:
                del conf[k]

            if expand:
                self.deep_expand_vars(val, config=self._config, **kwargs)
            elif type(val) == str:
                while val.find('{') != -1:
                    val = self.expand_vars(val, config=self._config, **kwargs)

            return val
        # Dotted key: recurse into the first path element.
        child = conf.get(keys_arr[0], {})
        child_keys = '.'.join(keys_arr[1:])
        return self.consume(child, child_keys, default=default, expand=expand, **kwargs)

    def get_version(self):
        """Return (and cache) the git commit sha of the enclosing repository."""
        if not self._version:
            repo = git.Repo(search_parent_directories=True)
            sha = repo.head.object.hexsha
            self._version = sha
        return self._version

    def _calc_num_jobs(self, schedC, num_conv_versions):
        """Number of (SLURM array) jobs needed to cover all repetitions.

        Total reps default to reps_per_version * num_conv_versions; each job
        hosts agents_per_job agents running reps_per_agent reps each.
        """
        schedulerC = copy.deepcopy(schedC)  # consume() mutates, so work on a copy
        reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1) * num_conv_versions)
        agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
        reps_per_job = reps_per_agent * agents_per_job
        jobs_needed = math.ceil(reps / reps_per_job)
        return jobs_needed

    def _reps_for_job(self, schedC, task_id, num_conv_versions):
        """Repetition indices assigned to array task task_id.

        Reps are dealt round-robin across jobs; task_id None (a local run
        without -t) gets every rep.
        """
        schedulerC = copy.deepcopy(schedC)
        num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions)
        reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1) * num_conv_versions)
        if task_id is None:
            return list(range(0, reps))
        reps_for_job = [[] for _ in range(num_jobs)]
        for i in range(reps):
            reps_for_job[i % num_jobs].append(i)
        return reps_for_job[task_id]

    def _make_configs_for_runs(self, config_exp_pairs):
        """
        Expand configurations across all provided experiments, grid, and ablation variants.

        Parameters:
            config_exp_pairs (list): A list of tuples where each tuple contains (filename, experiment name).

        Returns:
            list: A list of expanded configurations ready for execution.
        """
        all_expanded_configs = []
        for config, exp in config_exp_pairs:
            config_data = self.load_config(config, exp)
            grid_versions = self._make_grid_versions(config_data)
            exp_variants = self._make_ablative_versions(config_data, grid_versions)
            all_expanded_configs.extend(exp_variants)
        return all_expanded_configs

    def run_local(self, config_exp_pairs, task_id):
        """
        Run all expanded configurations locally, handling all variants and their repetitions concurrently.

        Parameters:
            config_exp_pairs (list): A list of tuples where each tuple contains (filename, experiment name).
            task_id (int): The task ID for the experiments.
        """
        self.task_id = task_id
        all_configs = self._make_configs_for_runs(config_exp_pairs)
        num_conv_versions = len(all_configs)
        # Scheduler settings are taken from the first expanded config only.
        schedulerC = copy.deepcopy(all_configs[0].get('scheduler', {}))
        rep_ids = self._reps_for_job(schedulerC, task_id, num_conv_versions)
        self._fork_processes(all_configs, rep_ids)

    def run_slurm(self, original_config_exp_string, config_exp_pairs):
        """
        Schedule all expanded configurations on SLURM within a single job.

        Parameters:
            original_config_exp_string (str): The original string of config:experiment pairs provided by the user.
            config_exp_pairs (list): A list of tuples where each tuple contains (filename, experiment name).
        """
        all_configs = self._make_configs_for_runs(config_exp_pairs)
        slurmC = self.consume(all_configs[0], 'slurm', expand=True)
        s_name = self.consume(slurmC, 'name')
        python_script = 'main.py'
        sh_lines = ['#!/bin/bash']
        sh_lines += self.consume(slurmC, 'sh_lines', [])
        if venv := self.consume(slurmC, 'venv', False):
            sh_lines += [f'source activate {venv}']
        # Use the original config:experiment string to avoid verbosity
        final_line = f'{python_script} {original_config_exp_string} -t $SLURM_ARRAY_TASK_ID'
        if self.consume(slurmC, 'python_exec', False):
            final_line = f'./omni_sif_python {final_line}'
        else:
            final_line = f'python3 {final_line}'
        if self.consume(slurmC, 'xvfb', False):
            final_line = f'xvfb-run {final_line}'
        sh_lines.append(final_line)
        script = "\n".join(sh_lines)

        num_jobs = self._calc_num_jobs(all_configs[0].get('scheduler', {}), len(all_configs))
        last_job_idx = num_jobs - 1
        num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
        array = f'0-{last_job_idx}%{num_parallel_jobs}'

        # Remaining slurmC keys are passed straight through to pyslurm.
        job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
        if self.verify:
            input(f'[!] Press Enter to submit the job to SLURM.')
        job_id = job.submit()
        print(f'[>] Job submitted to SLURM with id {job_id}')

        # Log file entry optimization
        with open('job_hist.log', 'a') as f:
            config_logs = {}
            for config, exp in config_exp_pairs:
                if config not in config_logs:
                    config_logs[config] = []
                config_logs[config].append(exp)
            for config, exps in config_logs.items():
                exps_str = ",".join(exps)
                f.write(f'{config}:{exps_str} submitted to SLURM with id {job_id}\n')

    def _fork_processes(self, configs, rep_ids):
        """
        Fork processes to run all expanded configurations concurrently.

        Parameters:
            configs (list): A list of expanded configurations.
            rep_ids (list): A list of repetition identifiers for the configurations.
        """
        schedC = self.consume(configs[0], 'scheduler', {})
        agents_per_job = self.consume(schedC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)

        node_reps = len(rep_ids)
        num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))

        if num_p == 1:
            print('[i] Running within main thread')
            self._run_process(configs, rep_ids=rep_ids, p_ind=0)
            return

        procs = []
        reps_done = 0
        for p in range(num_p):
            print(f'[i] Spawning separate thread/process ({p+1}/{num_p})')
            num_reps = min(node_reps - reps_done, reps_per_agent)
            proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done + num_reps))]
            proc_configs = [configs[i % len(configs)] for i in proc_rep_ids] # Distribute configs across processes
            proc = Parallelization_Primitive(target=partial(self._run_process, proc_configs, proc_rep_ids, p_ind=p))
            proc.start()
            procs.append(proc)
            reps_done += num_reps

        for proc in procs:
            proc.join()
        print('[i] All threads/processes have terminated')

    def _run_process(self, orig_configs, rep_ids, p_ind):
        """
        Run a single process for a subset of configurations.

        Parameters:
            orig_configs (list): A list of configurations to run.
            rep_ids (list): A list of repetition identifiers for the configurations.
            p_ind (int): Process index.
        """
        for r in rep_ids:
            self.run_id = r
            # Rep index r selects its config round-robin.
            config = orig_configs[r % len(orig_configs)]
            self._run_single(config, [r], p_ind=p_ind)

    def _run_single(self, orig_config, rep_ids, p_ind):
        """Execute the given reps sequentially inside this process/thread.

        Each rep gets its own deep copy of the config, a wandb run (retried
        up to 5 times on CommError) and a fresh Runner instance; afterwards
        any unconsumed config keys are reported or raise, depending on
        REQUIRE_CONFIG_CONSUMED.
        """
        print(f'[P{p_ind}] I will work on reps {rep_ids}')
        self._config = orig_config
        runnerName = self.consume(orig_config, 'runner')
        project = self.consume(orig_config, 'wandb.project', orig_config.get('project', orig_config.get('name')))
        Runner = self.runners[runnerName]

        if self.consume(orig_config, 'scheduler.bind_agent_to_core', False):
            # Pin this agent to one core, wrapping around the core count.
            os.sched_setaffinity(0, [p_ind % os.cpu_count()])

        for r in rep_ids:
            self.run_id = r
            runnerConf = copy.deepcopy(orig_config)
            wandbC = self.consume(runnerConf, 'wandb', {}, expand=True, delta_desc=runnerConf.pop('delta_desc', 'BASE'))
            # wandb limits job_type length; keep the (most specific) tail.
            if 'job_type' in wandbC and len(wandbC['job_type']) > 62:
                wandbC['job_type'] = "..."+wandbC['job_type'][-50:]

            retry = 5
            while retry:
                try:
                    with wandb.init(
                        project=project,
                        config=copy.deepcopy(runnerConf),
                        reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
                        settings=wandb.Settings(**self.consume(wandbC, 'settings', {})),
                        **wandbC
                    ) as run:
                        runner = Runner(self, runnerConf)
                        runner.setup(wandbC['group']+wandbC['job_type'])
                        runner.run(run)
                except wandb.errors.CommError as e:
                    retry -= 1
                    if retry:
                        print('Caught CommErr; retrying...')
                        # Randomized backoff of up to a minute.
                        time.sleep(int(60*random.random()))
                    else:
                        print('Caught CommErr; not retrying')
                        raise e
                else:
                    retry = 0

            # The runner is expected to consume every key it understands.
            if runnerConf != {}:
                msg = ('Config was not completely consumed: ', runnerConf)
                if REQUIRE_CONFIG_CONSUMED:
                    raise Exception(msg)
                else:
                    print(msg)
        orig_config = {}

    def _get_num_conv_versions(self, config):
        # NOTE(review): _make_configs_for_runs expects a list of
        # (filename, experiment) pairs; confirm callers pass that shape here.
        return len(self._make_configs_for_runs(config))

    def _make_grid_versions(self, config):
        """Expand a 'grid' section into one config per cartesian combination."""
        if 'grid' in config:
            return params_combine(config, 'grid', itertools.product)
        return [config]

    def _make_ablative_versions(self, config, grid_versions):
        """Append single-change ablation variants for an 'ablative' section."""
        if 'ablative' in config:
            return grid_versions + ablative_expand(grid_versions)
        else:
            return grid_versions

    def from_args(self):
        """CLI entry point: parse 'config:exp1,exp2 ...' pairs and dispatch.

        -s/--slurm schedules on SLURM, otherwise runs locally;
        -t/--task_id selects this SLURM array task's rep subset;
        --ask_verify asks for confirmation before submitting.
        """
        import argparse
        parser = argparse.ArgumentParser()
        parser.add_argument("config_experiments", nargs='+', help="List of config:experiment pairs")
        parser.add_argument("-s", "--slurm", action="store_true")
        parser.add_argument("-w", "--worker", action="store_true")
        parser.add_argument("-t", "--task_id", default=None, type=int)
        parser.add_argument("--ask_verify", action="store_true")
        args = parser.parse_args()

        print(f'[i] I have task_id {args.task_id}')
        print(f'[i] Running on version [git:{self.get_version()}]')

        if args.worker:
            raise Exception('Worker mode not yet implemented')

        config_exp_pairs = []
        for config_exp in args.config_experiments:
            config, exps = config_exp.split(":")
            exp_list = exps.split(",")
            for exp in exp_list:
                config_exp_pairs.append((config, exp))

        if args.slurm:
            if args.ask_verify:
                self.verify = True
            self.run_slurm(' '.join(args.config_experiments), config_exp_pairs)
        else:
            self.run_local(config_exp_pairs, args.task_id)
def params_combine(config: dict, key: str, iter_func):
    """Expand config[key] (a nested dict of value-lists) into one config per
    combination produced by iter_func (e.g. itertools.product).

    Each variant gets the combination written back at its nested paths and a
    'delta_desc' describing the chosen values. iter_func=None disables
    expansion and returns the config unchanged.
    """
    if iter_func is None:
        return [config]

    flat = flatten_dict_to_tuple_keys(config[key])
    param_names = ['.'.join(path) for path in flat]

    expanded = []
    for combo in iter_func(*flat.values()):
        variant = copy.deepcopy(config)
        del variant[key]
        for path, chosen in zip(flat.keys(), combo):
            insert_deep_dictionary(d=variant, t=path, value=chosen)
        variant = extend_config_name(variant, param_names, combo)
        expanded.append(variant)
    return expanded
def ablative_expand(conf_list):
    """For every config, emit one variant per single (parameter, value) change
    listed under its 'ablative' section, tagging each with a 'delta_desc'."""
    expanded = []
    for base in conf_list:
        flat = flatten_dict_to_tuple_keys(base['ablative'])
        names = ['.'.join(path) for path in flat]
        for name, (path, candidates) in zip(names, flat.items()):
            for candidate in candidates:
                variant = copy.deepcopy(base)
                insert_deep_dictionary(variant, path, candidate)
                variant = extend_config_name(variant, [name], [candidate])
                expanded.append(variant)
    return expanded
def flatten_dict_to_tuple_keys(d: MutableMapping):
    """Flatten a nested mapping into {tuple-path: list} entries.

    Only list-valued (MutableSequence) leaves are kept; scalar leaves are
    silently dropped. Nested mappings are recursed into, their keys joined
    into tuple paths.
    """
    flattened = {}
    for key, value in d.items():
        if isinstance(value, MutableMapping):
            for sub_path, sub_value in flatten_dict_to_tuple_keys(value).items():
                flattened[(key, *sub_path)] = sub_value
        elif isinstance(value, MutableSequence):
            flattened[(key,)] = value
    return flattened
def insert_deep_dictionary(d: MutableMapping, t: tuple, value):
    """Set value at the nested path given by tuple t, creating intermediate
    dicts as needed. A non-tuple t is used directly as a flat key."""
    if type(t) is not tuple:
        d[t] = value
        return
    node = d
    for part in t[:-1]:
        if part not in node:
            node[part] = dict()
        node = node[part]
    node[t[-1]] = value
def append_deep_dictionary(d: MutableMapping, t: tuple, value):
    """Append value to the list at nested path t, creating intermediate dicts
    and the leaf list as needed. A non-tuple t is assigned, not appended."""
    if type(t) is not tuple:
        d[t] = value
        return
    node = d
    for part in t[:-1]:
        if part not in node:
            node[part] = dict()
        node = node[part]
    if t[-1] not in node:
        node[t[-1]] = []
    node[t[-1]].append(value)
def extend_config_name(config: dict, param_names: list, values: list) -> dict:
    """Append a short description of the (param, value) deltas to
    config['delta_desc'] (creating it if absent) and return config."""
    suffix = convert_param_names(param_names, values)
    if 'delta_desc' in config:
        config['delta_desc'] = config['delta_desc'] + '_' + suffix
    else:
        config['delta_desc'] = suffix
    return config
def convert_param_names(_param_names: list, values: list) -> str:
    """Build a compact name fragment like 'sch.rpa4_lr0.1' from parameter
    paths and their values, stripping quotes/spaces and normalizing
    brackets/commas to underscores."""
    _converted_name = '_'.join("{}{}".format(shorten_param(k), v) for k, v in zip(_param_names, values))
    # Raw strings: '\[' in a non-raw literal is an invalid escape sequence
    # (SyntaxWarning since Python 3.12, slated to become an error).
    _converted_name = re.sub(r"[' ]", '', _converted_name)
    _converted_name = re.sub(r'["]', '', _converted_name)
    _converted_name = re.sub(r"[(\[]", '_', _converted_name)
    _converted_name = re.sub(r"[)\]]", '', _converted_name)
    _converted_name = re.sub(r"[,]", '_', _converted_name)
    return _converted_name
def shorten_param(_param_name):
    """Abbreviate a dotted parameter path: first 3 chars of each parent part,
    then the initials of the leaf's underscore-separated words
    (e.g. 'scheduler.reps_per_agent' -> 'sch.rpa')."""
    *parents, leaf = _param_name.split('.')
    prefix = '.'.join(part[:3] for part in parents)
    leaf_abbrev = ''.join(word[0] for word in leaf.split('_'))
    return f'{prefix}.{leaf_abbrev}' if prefix else leaf_abbrev
class Slate_Runner():
    """Base class for runners: holds the Slate instance and the run config.

    Subclasses override setup() and/or run(); both are no-ops here.
    """

    def __init__(self, slate, config):
        self.slate, self.config = slate, config

    def setup(self, name):
        """Hook called before run(); name is the wandb group + job_type."""
        pass

    def run(self, run):
        """Hook receiving the active wandb run object."""
        pass
class Print_Config_Runner(Slate_Runner):
    """Debug runner: pretty-prints the raw config and its fully expanded
    form, then consumes every key so the leftover-config check stays quiet."""

    def run(self, run):
        slate, config = self.slate, self.config
        pprint(config)
        print('---')
        pprint(slate.consume(config, '', expand=True))
        for key in list(config.keys()):
            config.pop(key)
class Void_Runner(Slate_Runner):
    """Runner that does nothing except consume the whole config."""

    def run(self, run):
        slate, config = self.slate, self.config
        for key in list(config.keys()):
            config.pop(key)
class PDB_Runner(Slate_Runner):
    """Runner that immediately drops into the pdb debugger (d = pdb.set_trace)."""
    def run(self, run):
        d()
if __name__ == '__main__':
    # This module is a library: build a Slate(...) in your project's main.py
    # and call from_args() instead of executing this file directly.
    raise Exception('You are using it wrong...')