Maybe better CLI interface and resource allocation

This commit is contained in:
Dominik Moritz Roth 2024-08-15 18:54:36 +02:00
parent 69c79bdb6f
commit ebca76e680

View File

@ -28,14 +28,10 @@ try:
import pyslurm import pyslurm
except ImportError: except ImportError:
slurm_avaible = False slurm_avaible = False
print('[!] Slurm not avaible.') print('[!] Slurm not available.')
else: else:
slurm_avaible = True slurm_avaible = True
# TODO: Implement Testing
# TODO: Implement Ablative
class Slate(): class Slate():
def __init__(self, runners): def __init__(self, runners):
self.runners = { self.runners = {
@ -49,7 +45,6 @@ class Slate():
self.task_id = None self.task_id = None
self.run_id = -1 self.run_id = -1
self._tmp_path = os.path.expandvars('$TMP') self._tmp_path = os.path.expandvars('$TMP')
self.sweep_id = None
self.verify = False self.verify = False
def load_config(self, filename, name): def load_config(self, filename, name):
@ -109,7 +104,7 @@ class Slate():
def expand_vars(self, string, delta_desc='BASE', **kwargs): def expand_vars(self, string, delta_desc='BASE', **kwargs):
if isinstance(string, str): if isinstance(string, str):
rand = int(random.random()*99999999) rand = int(random.random() * 99999999)
if string == '{rand}': if string == '{rand}':
return rand return rand
return string.format(delta_desc=delta_desc, **kwargs, rand=rand, tmp=self._tmp_path, job_id=(self.job_id or 'LOCAL'), task_id=(self.task_id or 0), run_id=self.run_id) return string.format(delta_desc=delta_desc, **kwargs, rand=rand, tmp=self._tmp_path, job_id=(self.job_id or 'LOCAL'), task_id=(self.task_id or 0), run_id=self.run_id)
@ -141,11 +136,11 @@ class Slate():
keys_arr = key.split('.') keys_arr = key.split('.')
if len(keys_arr) == 1: if len(keys_arr) == 1:
k = keys_arr[0] k = keys_arr[0]
if default != None: if default is not None:
if isinstance(conf, Mapping): if isinstance(conf, Mapping):
val = conf.get(k, default) val = conf.get(k, default)
else: else:
if default != None: if default is not None:
return default return default
raise Exception('') raise Exception('')
else: else:
@ -173,7 +168,7 @@ class Slate():
def _calc_num_jobs(self, schedC, num_conv_versions): def _calc_num_jobs(self, schedC, num_conv_versions):
schedulerC = copy.deepcopy(schedC) schedulerC = copy.deepcopy(schedC)
reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions) reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1) * num_conv_versions)
agents_per_job = self.consume(schedulerC, 'agents_per_job', 1) agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1) reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
reps_per_job = reps_per_agent * agents_per_job reps_per_job = reps_per_agent * agents_per_job
@ -183,84 +178,115 @@ class Slate():
def _reps_for_job(self, schedC, task_id, num_conv_versions): def _reps_for_job(self, schedC, task_id, num_conv_versions):
schedulerC = copy.deepcopy(schedC) schedulerC = copy.deepcopy(schedC)
num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions) num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions)
reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions) reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1) * num_conv_versions)
if task_id == None: if task_id is None:
return list(range(0, reps)) return list(range(0, reps))
reps_for_job = [[] for i in range(num_jobs)] reps_for_job = [[] for _ in range(num_jobs)]
for i in range(reps): for i in range(reps):
reps_for_job[i % num_jobs].append(i) reps_for_job[i % num_jobs].append(i)
return reps_for_job[task_id] return reps_for_job[task_id]
def run_local(self, filename, name, task_id, sweep_id): def _make_configs_for_runs(self, config_exp_pairs):
"""
Expand configurations across all provided experiments, grid, and ablation variants.
Parameters:
config_exp_pairs (list): A list of tuples where each tuple contains (filename, experiment name).
Returns:
list: A list of expanded configurations ready for execution.
"""
all_expanded_configs = []
for config, exp in config_exp_pairs:
config_data = self.load_config(config, exp)
grid_versions = self._make_grid_versions(config_data)
exp_variants = self._make_ablative_versions(config_data, grid_versions)
all_expanded_configs.extend(exp_variants)
return all_expanded_configs
def run_local(self, config_exp_pairs, task_id):
"""
Run all expanded configurations locally, handling all variants and their repetitions concurrently.
Parameters:
config_exp_pairs (list): A list of tuples where each tuple contains (filename, experiment name).
task_id (int): The task ID for the experiments.
"""
self.task_id = task_id self.task_id = task_id
config = self.load_config(filename, name) all_configs = self._make_configs_for_runs(config_exp_pairs)
num_conv_versions = self._get_num_conv_versions(config)
schedulerC = copy.deepcopy(config.get('scheduler', {})) num_conv_versions = len(all_configs)
schedulerC = copy.deepcopy(all_configs[0].get('scheduler', {}))
rep_ids = self._reps_for_job(schedulerC, task_id, num_conv_versions) rep_ids = self._reps_for_job(schedulerC, task_id, num_conv_versions)
self.sweep_id = sweep_id self._fork_processes(all_configs, rep_ids)
self._init_sweep(config)
self._fork_processes(config, rep_ids)
def run_slurm(self, filename, name): def run_slurm(self, original_config_exp_string, config_exp_pairs):
assert slurm_avaible, 'pyslurm does not seem to be installed on this system.' """
config = self.load_config(filename, name) Schedule all expanded configurations on SLURM within a single job.
slurmC = self.consume(config, 'slurm', expand=True)
schedC = self.consume(config, 'scheduler') Parameters:
original_config_exp_string (str): The original string of config:experiment pairs provided by the user.
config_exp_pairs (list): A list of tuples where each tuple contains (filename, experiment name).
"""
all_configs = self._make_configs_for_runs(config_exp_pairs)
slurmC = self.consume(all_configs[0], 'slurm', expand=True)
s_name = self.consume(slurmC, 'name') s_name = self.consume(slurmC, 'name')
num_conv_versions = self._get_num_conv_versions(config)
# Pre Validation
runnerName, _ = self.consume(config, 'runner'), self.consume(copy.deepcopy(config), 'wandb', {}, expand=True)
if self.consume(slurmC, 'pre_validate', True):
Runner = self.runners[runnerName]
runner = Runner(self, config)
runner.setup('PreDeployment-Validation')
self._init_sweep(config)
self.consume(config, 'wandb')
python_script = 'main.py' python_script = 'main.py'
sh_lines = ['#!/bin/bash'] sh_lines = ['#!/bin/bash']
sh_lines += self.consume(slurmC, 'sh_lines', []) sh_lines += self.consume(slurmC, 'sh_lines', [])
if venv := self.consume(slurmC, 'venv', False): if venv := self.consume(slurmC, 'venv', False):
sh_lines += [f'source activate {venv}'] sh_lines += [f'source activate {venv}']
final_line = f'{python_script} {filename} {name} -t $SLURM_ARRAY_TASK_ID --sweep_id {self.sweep_id}'
# Use the original config:experiment string to avoid verbosity
final_line = f'{python_script} {original_config_exp_string} -t $SLURM_ARRAY_TASK_ID'
if self.consume(slurmC, 'python_exec', False): if self.consume(slurmC, 'python_exec', False):
final_line = f'./omni_sif_python {final_line}' final_line = f'./omni_sif_python {final_line}'
else: else:
final_line = f'python3 {final_line}' final_line = f'python3 {final_line}'
if self.consume(slurmC, 'xvfb', False): if self.consume(slurmC, 'xvfb', False):
final_line = f'xvfb-run {final_line}' final_line = f'xvfb-run {final_line}'
sh_lines += [final_line] sh_lines.append(final_line)
script = "\n".join(sh_lines) script = "\n".join(sh_lines)
num_jobs = self._calc_num_jobs(schedC, num_conv_versions) num_jobs = self._calc_num_jobs(all_configs[0].get('scheduler', {}), len(all_configs))
last_job_idx = num_jobs - 1 last_job_idx = num_jobs - 1
num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs) num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
array = f'0-{last_job_idx}%{num_parallel_jobs}' array = f'0-{last_job_idx}%{num_parallel_jobs}'
job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC) job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
self.consume(config, 'name', '') if self.verify:
self.consume(config, 'project', '') input(f'[!] Press Enter to submit the job to SLURM.')
self.consume(config, 'vars', '')
self.consume(config, 'grid', '')
self.consume(config, 'ablative', '')
if config != {}:
print('[!] Unconsumed Config Parts:')
pprint(config)
if self.verify or config != {}:
input(f'<Press enter to submit {num_jobs} job(s) to slurm>')
job_id = job.submit() job_id = job.submit()
print(f'[>] Job submitted to slurm with id {job_id}') print(f'[>] Job submitted to SLURM with id {job_id}')
with open('job_hist.log', 'a') as f:
f.write(f'{filename}:{name} submitted to slurm with ids {job_id}_0-{num_jobs} on [git:{self.get_version()}] at {datetime.datetime.now()}\n')
def _fork_processes(self, config, rep_ids): # Log file entry optimization
schedC = self.consume(config, 'scheduler', {}) with open('job_hist.log', 'a') as f:
config_logs = {}
for config, exp in config_exp_pairs:
if config not in config_logs:
config_logs[config] = []
config_logs[config].append(exp)
for config, exps in config_logs.items():
exps_str = ",".join(exps)
f.write(f'{config}:{exps_str} submitted to SLURM with id {job_id}\n')
def _fork_processes(self, configs, rep_ids):
"""
Fork processes to run all expanded configurations concurrently.
Parameters:
configs (list): A list of expanded configurations.
rep_ids (list): A list of repetition identifiers for the configurations.
"""
schedC = self.consume(configs[0], 'scheduler', {})
agents_per_job = self.consume(schedC, 'agents_per_job', 1) agents_per_job = self.consume(schedC, 'agents_per_job', 1)
reps_per_agent = self.consume(schedC, 'reps_per_agent', 1) reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
@ -269,45 +295,39 @@ class Slate():
if num_p == 1: if num_p == 1:
print('[i] Running within main thread') print('[i] Running within main thread')
self._run_process(config, rep_ids=rep_ids, p_ind=0) self._run_process(configs, rep_ids=rep_ids, p_ind=0)
return return
procs = [] procs = []
reps_done = 0 reps_done = 0
for p in range(num_p): for p in range(num_p):
print(f'[i] Spawning seperate thread/process ({p+1}/{num_p})') print(f'[i] Spawning separate thread/process ({p+1}/{num_p})')
num_reps = min(node_reps - reps_done, reps_per_agent) num_reps = min(node_reps - reps_done, reps_per_agent)
proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done+num_reps))] proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done + num_reps))]
proc = Parallelization_Primitive(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p)) proc_configs = [configs[i % len(configs)] for i in proc_rep_ids] # Distribute configs across processes
proc = Parallelization_Primitive(target=partial(self._run_process, proc_configs, proc_rep_ids, p_ind=p))
proc.start() proc.start()
procs.append(proc) procs.append(proc)
reps_done += num_reps reps_done += num_reps
for proc in procs: for proc in procs:
proc.join() proc.join()
print(f'[i] All threads/processes have terminated') print('[i] All threads/processes have terminated')
def _init_sweep(self, config): def _run_process(self, orig_configs, rep_ids, p_ind):
if self.sweep_id == None and self.consume(config, 'sweep.enable', False): """
sweepC = self.consume(config, 'sweep') Run a single process for a subset of configurations.
wandbC = copy.deepcopy(config['wandb'])
project = self.consume(wandbC, 'project')
self.sweep_id = wandb.sweep( Parameters:
sweep=sweepC, configs (list): A list of configurations to run.
project=project rep_ids (list): A list of repetition identifiers for the configurations.
) p_ind (int): Process index.
"""
def _run_process(self, orig_config, rep_ids, p_ind): for r in rep_ids:
config = copy.deepcopy(orig_config) self.run_id = r
if self.consume(config, 'sweep.enable', False): config = orig_configs[r % len(orig_configs)]
wandbC = copy.deepcopy(config['wandb']) self._run_single(config, [r], p_ind=p_ind)
wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
else:
self.consume(config, 'sweep', {})
self._run_single(config, rep_ids, p_ind=p_ind)
def _run_single(self, orig_config, rep_ids, p_ind): def _run_single(self, orig_config, rep_ids, p_ind):
print(f'[P{p_ind}] I will work on reps {rep_ids}') print(f'[P{p_ind}] I will work on reps {rep_ids}')
@ -343,10 +363,10 @@ class Slate():
except wandb.errors.CommError as e: except wandb.errors.CommError as e:
retry -= 1 retry -= 1
if retry: if retry:
print('Catched CommErr; retrying...') print('Caught CommErr; retrying...')
time.sleep(int(60*random.random())) time.sleep(int(60*random.random()))
else: else:
print('Catched CommErr; not retrying') print('Caught CommErr; not retrying')
raise e raise e
else: else:
retry = 0 retry = 0
@ -359,36 +379,6 @@ class Slate():
print(msg) print(msg)
orig_config = {} orig_config = {}
def _run_from_sweep(self, orig_config, p_ind):
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
project = self.consume(wandbC, 'project')
Runner = self.runners[runnerName]
if orig_config.consume('scheduler.bind_agent_to_core', False):
os.sched_setaffinity(0, [p_ind % os.cpu_count()])
with wandb.init(
project=project,
reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
settings=wandb.Settings(**self.consume(wandbC, 'settings', {})),
**wandbC
) as run:
config = copy.deepcopy(orig_config)
self.deep_update(config, wandb.config)
run.config = copy.deepcopy(config)
runner = Runner(self, config)
runner.setup(wandbC['group']+wandbC['job_type'])
runner.run(run)
if config != {}:
msg = ('Config was not completely consumed: ', config)
if REQUIRE_CONFIG_CONSUMED:
raise Exception(msg)
else:
print(msg)
orig_config = {}
def _make_configs_for_runs(self, config): def _make_configs_for_runs(self, config):
c = copy.deepcopy(config) c = copy.deepcopy(config)
@ -425,12 +415,10 @@ class Slate():
import argparse import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("config_file", nargs='?', default=None) parser.add_argument("config_experiments", nargs='+', help="List of config:experiment pairs")
parser.add_argument("experiment", nargs='?', default='DEFAULT')
parser.add_argument("-s", "--slurm", action="store_true") parser.add_argument("-s", "--slurm", action="store_true")
parser.add_argument("-w", "--worker", action="store_true") parser.add_argument("-w", "--worker", action="store_true")
parser.add_argument("-t", "--task_id", default=None, type=int) parser.add_argument("-t", "--task_id", default=None, type=int)
parser.add_argument("--sweep_id", default=None, type=str)
parser.add_argument("--ask_verify", action="store_true") parser.add_argument("--ask_verify", action="store_true")
args = parser.parse_args() args = parser.parse_args()
@ -439,16 +427,21 @@ class Slate():
print(f'[i] Running on version [git:{self.get_version()}]') print(f'[i] Running on version [git:{self.get_version()}]')
if args.worker: if args.worker:
raise Exception('Not yet implemented') raise Exception('Worker mode not yet implemented')
assert args.config_file != None, 'Need to supply config file.' config_exp_pairs = []
for config_exp in args.config_experiments:
config, exps = config_exp.split(":")
exp_list = exps.split(",")
for exp in exp_list:
config_exp_pairs.append((config, exp))
if args.slurm: if args.slurm:
if args.ask_verify: if args.ask_verify:
self.verify = True self.verify = True
self.run_slurm(args.config_file, args.experiment) self.run_slurm(args.config_experiments, config_exp_pairs)
else: else:
self.run_local(args.config_file, args.experiment, args.task_id, args.sweep_id) self.run_local(config_exp_pairs, args.task_id)
def params_combine(config: dict, key: str, iter_func): def params_combine(config: dict, key: str, iter_func):
@ -456,21 +449,13 @@ def params_combine(config: dict, key: str, iter_func):
return [config] return [config]
combined_configs = [] combined_configs = []
# convert list/grid dictionary into flat dictionary, where the key is a tuple of the keys and the
# value is the list of values
tuple_dict = flatten_dict_to_tuple_keys(config[key]) tuple_dict = flatten_dict_to_tuple_keys(config[key])
_param_names = ['.'.join(t) for t in tuple_dict] _param_names = ['.'.join(t) for t in tuple_dict]
param_lengths = map(len, tuple_dict.values())
# create a new config for each parameter setting
for values in iter_func(*tuple_dict.values()): for values in iter_func(*tuple_dict.values()):
_config = copy.deepcopy(config) _config = copy.deepcopy(config)
# Remove Grid/List Argument
del _config[key] del _config[key]
# Expand Grid/List Parameters
for i, t in enumerate(tuple_dict.keys()): for i, t in enumerate(tuple_dict.keys()):
insert_deep_dictionary(d=_config, t=t, value=values[i]) insert_deep_dictionary(d=_config, t=t, value=values[i])
@ -488,11 +473,7 @@ def ablative_expand(conf_list):
for i, key in enumerate(tuple_dict): for i, key in enumerate(tuple_dict):
for val in tuple_dict[key]: for val in tuple_dict[key]:
_config = copy.deepcopy(config) _config = copy.deepcopy(config)
insert_deep_dictionary(_config, key, val)
insert_deep_dictionary(
_config, key, val
)
_config = extend_config_name(_config, [_param_names[i]], [val]) _config = extend_config_name(_config, [_param_names[i]], [val])
combined_configs.append(_config) combined_configs.append(_config)
return combined_configs return combined_configs
@ -504,18 +485,16 @@ def flatten_dict_to_tuple_keys(d: MutableMapping):
if isinstance(v, MutableMapping): if isinstance(v, MutableMapping):
sub_dict = flatten_dict_to_tuple_keys(v) sub_dict = flatten_dict_to_tuple_keys(v)
flat_dict.update({(k, *sk): sv for sk, sv in sub_dict.items()}) flat_dict.update({(k, *sk): sv for sk, sv in sub_dict.items()})
elif isinstance(v, MutableSequence): elif isinstance(v, MutableSequence):
flat_dict[(k,)] = v flat_dict[(k,)] = v
return flat_dict return flat_dict
def insert_deep_dictionary(d: MutableMapping, t: tuple, value): def insert_deep_dictionary(d: MutableMapping, t: tuple, value):
if type(t) is tuple: if type(t) is tuple:
if len(t) == 1: # tuple contains only one key if len(t) == 1:
d[t[0]] = value d[t[0]] = value
else: # tuple contains more than one key else:
if t[0] not in d: if t[0] not in d:
d[t[0]] = dict() d[t[0]] = dict()
insert_deep_dictionary(d[t[0]], t[1:], value) insert_deep_dictionary(d[t[0]], t[1:], value)
@ -525,11 +504,11 @@ def insert_deep_dictionary(d: MutableMapping, t: tuple, value):
def append_deep_dictionary(d: MutableMapping, t: tuple, value): def append_deep_dictionary(d: MutableMapping, t: tuple, value):
if type(t) is tuple: if type(t) is tuple:
if len(t) == 1: # tuple contains only one key if len(t) is 1:
if t[0] not in d: if t[0] not in d:
d[t[0]] = [] d[t[0]] = []
d[t[0]].append(value) d[t[0]].append(value)
else: # tuple contains more than one key else:
if t[0] not in d: if t[0] not in d:
d[t[0]] = dict() d[t[0]] = dict()
append_deep_dictionary(d[t[0]], t[1:], value) append_deep_dictionary(d[t[0]], t[1:], value)
@ -539,15 +518,12 @@ def append_deep_dictionary(d: MutableMapping, t: tuple, value):
def extend_config_name(config: dict, param_names: list, values: list) -> dict: def extend_config_name(config: dict, param_names: list, values: list) -> dict:
_converted_name = convert_param_names(param_names, values) _converted_name = convert_param_names(param_names, values)
config['delta_desc'] = config['delta_desc'] + '_' + _converted_name if 'delta_desc' in config else _converted_name config['delta_desc'] = config['delta_desc'] + '_' + _converted_name if 'delta_desc' in config else _converted_name
return config return config
def convert_param_names(_param_names: list, values: list) -> str: def convert_param_names(_param_names: list, values: list) -> str:
_converted_name = '_'.join("{}{}".format( _converted_name = '_'.join("{}{}".format(shorten_param(k), v) for k, v in zip(_param_names, values))
shorten_param(k), v) for k, v in zip(_param_names, values))
# _converted_name = re.sub("[' \[\],()]", '', _converted_name)
_converted_name = re.sub("[' ]", '', _converted_name) _converted_name = re.sub("[' ]", '', _converted_name)
_converted_name = re.sub('["]', '', _converted_name) _converted_name = re.sub('["]', '', _converted_name)
_converted_name = re.sub("[(\[]", '_', _converted_name) _converted_name = re.sub("[(\[]", '_', _converted_name)