Maybe better CLI interface and resource allocation
This commit is contained in:
parent
69c79bdb6f
commit
ebca76e680
262
slate/slate.py
@@ -28,14 +28,10 @@ try:
import pyslurm
except ImportError:
slurm_avaible = False
print('[!] Slurm not avaible.')
print('[!] Slurm not available.')
else:
slurm_avaible = True

# TODO: Implement Testing
# TODO: Implement Ablative


class Slate():
def __init__(self, runners):
self.runners = {
@@ -49,7 +45,6 @@ class Slate():
self.task_id = None
self.run_id = -1
self._tmp_path = os.path.expandvars('$TMP')
self.sweep_id = None
self.verify = False

def load_config(self, filename, name):
@@ -109,7 +104,7 @@ class Slate():

def expand_vars(self, string, delta_desc='BASE', **kwargs):
if isinstance(string, str):
rand = int(random.random()*99999999)
rand = int(random.random() * 99999999)
if string == '{rand}':
return rand
return string.format(delta_desc=delta_desc, **kwargs, rand=rand, tmp=self._tmp_path, job_id=(self.job_id or 'LOCAL'), task_id=(self.task_id or 0), run_id=self.run_id)
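For illustration of the placeholder expansion above, a minimal sketch with assumed values (the template string, run_id, and the resolved $TMP path are examples, not part of the commit):

# Illustration only: mirrors the format() call in expand_vars with assumed values.
import random

tmp_path = '/tmp'                        # stand-in for os.path.expandvars('$TMP')
rand = int(random.random() * 99999999)   # fresh random id, as in expand_vars
template = '{tmp}/slate_{job_id}_{task_id}_{run_id}_{rand}'
print(template.format(delta_desc='BASE', rand=rand, tmp=tmp_path,
                      job_id='LOCAL', task_id=0, run_id=3))
# e.g. /tmp/slate_LOCAL_0_3_48151623; note that expand_vars('{rand}') returns the integer itself.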
@@ -141,11 +136,11 @@ class Slate():
keys_arr = key.split('.')
if len(keys_arr) == 1:
k = keys_arr[0]
if default != None:
if default is not None:
if isinstance(conf, Mapping):
val = conf.get(k, default)
else:
if default != None:
if default is not None:
return default
raise Exception('')
else:
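The consume helper is only partially visible in this hunk. As a hedged, standalone sketch of the read-and-remove pattern its callers rely on (dotted keys, an optional default, and removal of consumed keys so leftovers can later be reported as 'Unconsumed Config Parts'); the expand=True handling and Mapping special cases of the real method are omitted here:

# Minimal consume-style sketch (an assumption, not the commit's implementation):
def consume(conf: dict, key: str, default=None):
    k, _, rest = key.partition('.')
    if rest:                       # descend along the dotted path, e.g. 'sweep.enable'
        return consume(conf[k], rest, default)
    if k in conf:
        return conf.pop(k)         # read and remove, so leftover keys can be detected later
    if default is not None:
        return default
    raise KeyError(key)

Calls such as self.consume(slurmC, 'num_parallel_jobs', num_jobs) elsewhere in the diff follow this pattern: return the value if present, fall back to the default otherwise, and drop the key from the dict.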
@@ -173,7 +168,7 @@ class Slate():

def _calc_num_jobs(self, schedC, num_conv_versions):
schedulerC = copy.deepcopy(schedC)
reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions)
reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1) * num_conv_versions)
agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
reps_per_job = reps_per_agent * agents_per_job
@@ -183,84 +178,115 @@ class Slate():
def _reps_for_job(self, schedC, task_id, num_conv_versions):
schedulerC = copy.deepcopy(schedC)
num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions)
reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions)
if task_id == None:
reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1) * num_conv_versions)
if task_id is None:
return list(range(0, reps))
reps_for_job = [[] for i in range(num_jobs)]
reps_for_job = [[] for _ in range(num_jobs)]
for i in range(reps):
reps_for_job[i % num_jobs].append(i)
return reps_for_job[task_id]
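To make the scheduling arithmetic above concrete, here is a small standalone sketch (the numbers are assumed for illustration; num_jobs is whatever _calc_num_jobs derives from repetitions, agents_per_job and reps_per_agent):

# Round-robin split of repetitions over array tasks, as in _reps_for_job (illustration only):
def reps_for_task(reps, num_jobs, task_id):
    buckets = [[] for _ in range(num_jobs)]
    for i in range(reps):
        buckets[i % num_jobs].append(i)
    return buckets[task_id]

# 7 repetitions over 3 jobs: task 0 -> [0, 3, 6], task 1 -> [1, 4], task 2 -> [2, 5]
assert reps_for_task(7, 3, 0) == [0, 3, 6]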

def run_local(self, filename, name, task_id, sweep_id):
def _make_configs_for_runs(self, config_exp_pairs):
"""
Expand configurations across all provided experiments, grid, and ablation variants.

Parameters:
config_exp_pairs (list): A list of tuples where each tuple contains (filename, experiment name).

Returns:
list: A list of expanded configurations ready for execution.
"""
all_expanded_configs = []

for config, exp in config_exp_pairs:
config_data = self.load_config(config, exp)
grid_versions = self._make_grid_versions(config_data)
exp_variants = self._make_ablative_versions(config_data, grid_versions)
all_expanded_configs.extend(exp_variants)

return all_expanded_configs

def run_local(self, config_exp_pairs, task_id):
"""
Run all expanded configurations locally, handling all variants and their repetitions concurrently.

Parameters:
config_exp_pairs (list): A list of tuples where each tuple contains (filename, experiment name).
task_id (int): The task ID for the experiments.
"""
self.task_id = task_id
config = self.load_config(filename, name)
num_conv_versions = self._get_num_conv_versions(config)
schedulerC = copy.deepcopy(config.get('scheduler', {}))
all_configs = self._make_configs_for_runs(config_exp_pairs)

num_conv_versions = len(all_configs)
schedulerC = copy.deepcopy(all_configs[0].get('scheduler', {}))
rep_ids = self._reps_for_job(schedulerC, task_id, num_conv_versions)
self.sweep_id = sweep_id
self._init_sweep(config)
self._fork_processes(config, rep_ids)
self._fork_processes(all_configs, rep_ids)

def run_slurm(self, filename, name):
assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
config = self.load_config(filename, name)
slurmC = self.consume(config, 'slurm', expand=True)
schedC = self.consume(config, 'scheduler')
def run_slurm(self, original_config_exp_string, config_exp_pairs):
"""
Schedule all expanded configurations on SLURM within a single job.

Parameters:
original_config_exp_string (str): The original string of config:experiment pairs provided by the user.
config_exp_pairs (list): A list of tuples where each tuple contains (filename, experiment name).
"""
all_configs = self._make_configs_for_runs(config_exp_pairs)

slurmC = self.consume(all_configs[0], 'slurm', expand=True)
s_name = self.consume(slurmC, 'name')

num_conv_versions = self._get_num_conv_versions(config)

# Pre Validation
runnerName, _ = self.consume(config, 'runner'), self.consume(copy.deepcopy(config), 'wandb', {}, expand=True)
if self.consume(slurmC, 'pre_validate', True):
Runner = self.runners[runnerName]
runner = Runner(self, config)
runner.setup('PreDeployment-Validation')

self._init_sweep(config)
self.consume(config, 'wandb')

python_script = 'main.py'
sh_lines = ['#!/bin/bash']
sh_lines += self.consume(slurmC, 'sh_lines', [])
if venv := self.consume(slurmC, 'venv', False):
sh_lines += [f'source activate {venv}']
final_line = f'{python_script} {filename} {name} -t $SLURM_ARRAY_TASK_ID --sweep_id {self.sweep_id}'

# Use the original config:experiment string to avoid verbosity
final_line = f'{python_script} {original_config_exp_string} -t $SLURM_ARRAY_TASK_ID'

if self.consume(slurmC, 'python_exec', False):
final_line = f'./omni_sif_python {final_line}'
else:
final_line = f'python3 {final_line}'
if self.consume(slurmC, 'xvfb', False):
final_line = f'xvfb-run {final_line}'
sh_lines += [final_line]
sh_lines.append(final_line)

script = "\n".join(sh_lines)
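For illustration, the kind of submit script this builds. All slurm config values and file names below are assumed, and the config:experiment string is shown as it would appear on the command line; none of this is part of the commit:

# Hypothetical resulting value of `script`, assuming sh_lines=['module load cuda'],
# venv='myenv', python_exec and xvfb disabled, and pairs given as cfg.yaml:exp1,exp2:
expected_script = "\n".join([
    '#!/bin/bash',
    'module load cuda',
    'source activate myenv',
    'python3 main.py cfg.yaml:exp1,exp2 -t $SLURM_ARRAY_TASK_ID',
])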

num_jobs = self._calc_num_jobs(schedC, num_conv_versions)

num_jobs = self._calc_num_jobs(all_configs[0].get('scheduler', {}), len(all_configs))
last_job_idx = num_jobs - 1
num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
array = f'0-{last_job_idx}%{num_parallel_jobs}'
job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
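A quick worked example of the array string (numbers assumed for illustration): 12 jobs throttled to at most 4 concurrent array tasks.

# Illustration only:
num_jobs = 12
last_job_idx = num_jobs - 1                       # 11
num_parallel_jobs = min(4, num_jobs)              # 4
array = f'0-{last_job_idx}%{num_parallel_jobs}'   # '0-11%4': tasks 0..11, at most 4 running at once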

self.consume(config, 'name', '')
self.consume(config, 'project', '')
self.consume(config, 'vars', '')
if self.verify:
input(f'[!] Press Enter to submit the job to SLURM.')

self.consume(config, 'grid', '')
self.consume(config, 'ablative', '')

if config != {}:
print('[!] Unconsumed Config Parts:')
pprint(config)
if self.verify or config != {}:
input(f'<Press enter to submit {num_jobs} job(s) to slurm>')
job_id = job.submit()
print(f'[>] Job submitted to slurm with id {job_id}')
with open('job_hist.log', 'a') as f:
f.write(f'{filename}:{name} submitted to slurm with ids {job_id}_0-{num_jobs} on [git:{self.get_version()}] at {datetime.datetime.now()}\n')
print(f'[>] Job submitted to SLURM with id {job_id}')

def _fork_processes(self, config, rep_ids):
schedC = self.consume(config, 'scheduler', {})
# Log file entry optimization
with open('job_hist.log', 'a') as f:
config_logs = {}
for config, exp in config_exp_pairs:
if config not in config_logs:
config_logs[config] = []
config_logs[config].append(exp)

for config, exps in config_logs.items():
exps_str = ",".join(exps)
f.write(f'{config}:{exps_str} submitted to SLURM with id {job_id}\n')
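To illustrate the grouping above (the pairs and job id are assumed): experiments that share a config file are collapsed back into a single config:exp1,exp2 line in job_hist.log.

# Illustration only, with assumed pairs and job id:
config_exp_pairs = [('cfg.yaml', 'exp1'), ('cfg.yaml', 'exp2'), ('other.yaml', 'base')]
config_logs = {}
for config, exp in config_exp_pairs:
    config_logs.setdefault(config, []).append(exp)
# config_logs == {'cfg.yaml': ['exp1', 'exp2'], 'other.yaml': ['base']}
# logged: 'cfg.yaml:exp1,exp2 submitted to SLURM with id 1234567'
#         'other.yaml:base submitted to SLURM with id 1234567'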

def _fork_processes(self, configs, rep_ids):
"""
Fork processes to run all expanded configurations concurrently.

Parameters:
configs (list): A list of expanded configurations.
rep_ids (list): A list of repetition identifiers for the configurations.
"""
schedC = self.consume(configs[0], 'scheduler', {})
agents_per_job = self.consume(schedC, 'agents_per_job', 1)
reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)

@@ -269,45 +295,39 @@ class Slate():

if num_p == 1:
print('[i] Running within main thread')
self._run_process(config, rep_ids=rep_ids, p_ind=0)
self._run_process(configs, rep_ids=rep_ids, p_ind=0)
return

procs = []

reps_done = 0

for p in range(num_p):
print(f'[i] Spawning seperate thread/process ({p+1}/{num_p})')
print(f'[i] Spawning separate thread/process ({p+1}/{num_p})')
num_reps = min(node_reps - reps_done, reps_per_agent)
proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done+num_reps))]
proc = Parallelization_Primitive(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done + num_reps))]
proc_configs = [configs[i % len(configs)] for i in proc_rep_ids] # Distribute configs across processes
proc = Parallelization_Primitive(target=partial(self._run_process, proc_configs, proc_rep_ids, p_ind=p))
proc.start()
procs.append(proc)
reps_done += num_reps

for proc in procs:
proc.join()
print(f'[i] All threads/processes have terminated')
print('[i] All threads/processes have terminated')

def _init_sweep(self, config):
if self.sweep_id == None and self.consume(config, 'sweep.enable', False):
sweepC = self.consume(config, 'sweep')
wandbC = copy.deepcopy(config['wandb'])
project = self.consume(wandbC, 'project')
def _run_process(self, orig_configs, rep_ids, p_ind):
"""
Run a single process for a subset of configurations.

self.sweep_id = wandb.sweep(
sweep=sweepC,
project=project
)

def _run_process(self, orig_config, rep_ids, p_ind):
config = copy.deepcopy(orig_config)
if self.consume(config, 'sweep.enable', False):
wandbC = copy.deepcopy(config['wandb'])
wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
else:
self.consume(config, 'sweep', {})
self._run_single(config, rep_ids, p_ind=p_ind)
Parameters:
configs (list): A list of configurations to run.
rep_ids (list): A list of repetition identifiers for the configurations.
p_ind (int): Process index.
"""
for r in rep_ids:
self.run_id = r
config = orig_configs[r % len(orig_configs)]
self._run_single(config, [r], p_ind=p_ind)

def _run_single(self, orig_config, rep_ids, p_ind):
print(f'[P{p_ind}] I will work on reps {rep_ids}')
@@ -343,10 +363,10 @@ class Slate():
except wandb.errors.CommError as e:
retry -= 1
if retry:
print('Catched CommErr; retrying...')
print('Caught CommErr; retrying...')
time.sleep(int(60*random.random()))
else:
print('Catched CommErr; not retrying')
print('Caught CommErr; not retrying')
raise e
else:
retry = 0
@@ -359,36 +379,6 @@ class Slate():
print(msg)
orig_config = {}

def _run_from_sweep(self, orig_config, p_ind):
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
project = self.consume(wandbC, 'project')

Runner = self.runners[runnerName]

if orig_config.consume('scheduler.bind_agent_to_core', False):
os.sched_setaffinity(0, [p_ind % os.cpu_count()])

with wandb.init(
project=project,
reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
settings=wandb.Settings(**self.consume(wandbC, 'settings', {})),
**wandbC
) as run:
config = copy.deepcopy(orig_config)
self.deep_update(config, wandb.config)
run.config = copy.deepcopy(config)
runner = Runner(self, config)
runner.setup(wandbC['group']+wandbC['job_type'])
runner.run(run)

if config != {}:
msg = ('Config was not completely consumed: ', config)
if REQUIRE_CONFIG_CONSUMED:
raise Exception(msg)
else:
print(msg)
orig_config = {}

def _make_configs_for_runs(self, config):
c = copy.deepcopy(config)

@@ -425,12 +415,10 @@ class Slate():
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("config_file", nargs='?', default=None)
parser.add_argument("experiment", nargs='?', default='DEFAULT')
parser.add_argument("config_experiments", nargs='+', help="List of config:experiment pairs")
parser.add_argument("-s", "--slurm", action="store_true")
parser.add_argument("-w", "--worker", action="store_true")
parser.add_argument("-t", "--task_id", default=None, type=int)
parser.add_argument("--sweep_id", default=None, type=str)
parser.add_argument("--ask_verify", action="store_true")

args = parser.parse_args()
@@ -439,16 +427,21 @@ class Slate():
print(f'[i] Running on version [git:{self.get_version()}]')

if args.worker:
raise Exception('Not yet implemented')
raise Exception('Worker mode not yet implemented')

assert args.config_file != None, 'Need to supply config file.'
config_exp_pairs = []
for config_exp in args.config_experiments:
config, exps = config_exp.split(":")
exp_list = exps.split(",")
for exp in exp_list:
config_exp_pairs.append((config, exp))
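For illustration (the invocation and file names are assumed): a call like python3 main.py cfg.yaml:exp1,exp2 other.yaml:base -s is parsed into the following pairs before being handed to run_slurm or run_local.

# Illustration of the parsing loop above with an assumed invocation:
config_experiments = ['cfg.yaml:exp1,exp2', 'other.yaml:base']   # what argparse collects via nargs='+'
config_exp_pairs = []
for config_exp in config_experiments:
    config, exps = config_exp.split(":")
    for exp in exps.split(","):
        config_exp_pairs.append((config, exp))
# config_exp_pairs == [('cfg.yaml', 'exp1'), ('cfg.yaml', 'exp2'), ('other.yaml', 'base')]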

if args.slurm:
if args.ask_verify:
self.verify = True
self.run_slurm(args.config_file, args.experiment)
self.run_slurm(args.config_experiments, config_exp_pairs)
else:
self.run_local(args.config_file, args.experiment, args.task_id, args.sweep_id)
self.run_local(config_exp_pairs, args.task_id)


def params_combine(config: dict, key: str, iter_func):
@@ -456,21 +449,13 @@ def params_combine(config: dict, key: str, iter_func):
return [config]

combined_configs = []
# convert list/grid dictionary into flat dictionary, where the key is a tuple of the keys and the
# value is the list of values
tuple_dict = flatten_dict_to_tuple_keys(config[key])
_param_names = ['.'.join(t) for t in tuple_dict]

param_lengths = map(len, tuple_dict.values())

# create a new config for each parameter setting
for values in iter_func(*tuple_dict.values()):
_config = copy.deepcopy(config)

# Remove Grid/List Argument
del _config[key]

# Expand Grid/List Parameters
for i, t in enumerate(tuple_dict.keys()):
insert_deep_dictionary(d=_config, t=t, value=values[i])

@@ -488,11 +473,7 @@ def ablative_expand(conf_list):
for i, key in enumerate(tuple_dict):
for val in tuple_dict[key]:
_config = copy.deepcopy(config)

insert_deep_dictionary(
_config, key, val
)

insert_deep_dictionary(_config, key, val)
_config = extend_config_name(_config, [_param_names[i]], [val])
combined_configs.append(_config)
return combined_configs
@@ -504,18 +485,16 @@ def flatten_dict_to_tuple_keys(d: MutableMapping):
if isinstance(v, MutableMapping):
sub_dict = flatten_dict_to_tuple_keys(v)
flat_dict.update({(k, *sk): sv for sk, sv in sub_dict.items()})

elif isinstance(v, MutableSequence):
flat_dict[(k,)] = v

return flat_dict


def insert_deep_dictionary(d: MutableMapping, t: tuple, value):
if type(t) is tuple:
if len(t) == 1: # tuple contains only one key
if len(t) == 1:
d[t[0]] = value
else: # tuple contains more than one key
else:
if t[0] not in d:
d[t[0]] = dict()
insert_deep_dictionary(d[t[0]], t[1:], value)
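A small worked example of the two helpers above. The grid values and the import path (inferred from the file name slate/slate.py) are assumptions, not part of the commit:

# Illustration only:
from slate.slate import flatten_dict_to_tuple_keys, insert_deep_dictionary

grid = {'algo': {'lr': [0.001, 0.01], 'gamma': [0.9, 0.99]}}
flat = flatten_dict_to_tuple_keys(grid)
# flat == {('algo', 'lr'): [0.001, 0.01], ('algo', 'gamma'): [0.9, 0.99]}
# '.'.join(('algo', 'lr')) == 'algo.lr'  -> the dotted _param_names used for naming variants

cfg = {}
insert_deep_dictionary(cfg, ('algo', 'lr'), 0.01)
# cfg == {'algo': {'lr': 0.01}}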
@@ -525,11 +504,11 @@ def insert_deep_dictionary(d: MutableMapping, t: tuple, value):

def append_deep_dictionary(d: MutableMapping, t: tuple, value):
if type(t) is tuple:
if len(t) == 1: # tuple contains only one key
if len(t) is 1:
if t[0] not in d:
d[t[0]] = []
d[t[0]].append(value)
else: # tuple contains more than one key
else:
if t[0] not in d:
d[t[0]] = dict()
append_deep_dictionary(d[t[0]], t[1:], value)
@@ -539,15 +518,12 @@ def append_deep_dictionary(d: MutableMapping, t: tuple, value):

def extend_config_name(config: dict, param_names: list, values: list) -> dict:
_converted_name = convert_param_names(param_names, values)

config['delta_desc'] = config['delta_desc'] + '_' + _converted_name if 'delta_desc' in config else _converted_name
return config


def convert_param_names(_param_names: list, values: list) -> str:
_converted_name = '_'.join("{}{}".format(
shorten_param(k), v) for k, v in zip(_param_names, values))
# _converted_name = re.sub("[' \[\],()]", '', _converted_name)
_converted_name = '_'.join("{}{}".format(shorten_param(k), v) for k, v in zip(_param_names, values))
_converted_name = re.sub("[' ]", '', _converted_name)
_converted_name = re.sub('["]', '', _converted_name)
_converted_name = re.sub("[(\[]", '_', _converted_name)