Maybe better cli interface und ressource allocation
This commit is contained in:
parent
69c79bdb6f
commit
ebca76e680
262
slate/slate.py
262
slate/slate.py
@ -28,14 +28,10 @@ try:
|
|||||||
import pyslurm
|
import pyslurm
|
||||||
except ImportError:
|
except ImportError:
|
||||||
slurm_avaible = False
|
slurm_avaible = False
|
||||||
print('[!] Slurm not avaible.')
|
print('[!] Slurm not available.')
|
||||||
else:
|
else:
|
||||||
slurm_avaible = True
|
slurm_avaible = True
|
||||||
|
|
||||||
# TODO: Implement Testing
|
|
||||||
# TODO: Implement Ablative
|
|
||||||
|
|
||||||
|
|
||||||
class Slate():
|
class Slate():
|
||||||
def __init__(self, runners):
|
def __init__(self, runners):
|
||||||
self.runners = {
|
self.runners = {
|
||||||
@ -49,7 +45,6 @@ class Slate():
|
|||||||
self.task_id = None
|
self.task_id = None
|
||||||
self.run_id = -1
|
self.run_id = -1
|
||||||
self._tmp_path = os.path.expandvars('$TMP')
|
self._tmp_path = os.path.expandvars('$TMP')
|
||||||
self.sweep_id = None
|
|
||||||
self.verify = False
|
self.verify = False
|
||||||
|
|
||||||
def load_config(self, filename, name):
|
def load_config(self, filename, name):
|
||||||
@ -109,7 +104,7 @@ class Slate():
|
|||||||
|
|
||||||
def expand_vars(self, string, delta_desc='BASE', **kwargs):
|
def expand_vars(self, string, delta_desc='BASE', **kwargs):
|
||||||
if isinstance(string, str):
|
if isinstance(string, str):
|
||||||
rand = int(random.random()*99999999)
|
rand = int(random.random() * 99999999)
|
||||||
if string == '{rand}':
|
if string == '{rand}':
|
||||||
return rand
|
return rand
|
||||||
return string.format(delta_desc=delta_desc, **kwargs, rand=rand, tmp=self._tmp_path, job_id=(self.job_id or 'LOCAL'), task_id=(self.task_id or 0), run_id=self.run_id)
|
return string.format(delta_desc=delta_desc, **kwargs, rand=rand, tmp=self._tmp_path, job_id=(self.job_id or 'LOCAL'), task_id=(self.task_id or 0), run_id=self.run_id)
|
||||||
@ -141,11 +136,11 @@ class Slate():
|
|||||||
keys_arr = key.split('.')
|
keys_arr = key.split('.')
|
||||||
if len(keys_arr) == 1:
|
if len(keys_arr) == 1:
|
||||||
k = keys_arr[0]
|
k = keys_arr[0]
|
||||||
if default != None:
|
if default is not None:
|
||||||
if isinstance(conf, Mapping):
|
if isinstance(conf, Mapping):
|
||||||
val = conf.get(k, default)
|
val = conf.get(k, default)
|
||||||
else:
|
else:
|
||||||
if default != None:
|
if default is not None:
|
||||||
return default
|
return default
|
||||||
raise Exception('')
|
raise Exception('')
|
||||||
else:
|
else:
|
||||||
@ -173,7 +168,7 @@ class Slate():
|
|||||||
|
|
||||||
def _calc_num_jobs(self, schedC, num_conv_versions):
|
def _calc_num_jobs(self, schedC, num_conv_versions):
|
||||||
schedulerC = copy.deepcopy(schedC)
|
schedulerC = copy.deepcopy(schedC)
|
||||||
reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions)
|
reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1) * num_conv_versions)
|
||||||
agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
|
agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
|
||||||
reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
|
reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
|
||||||
reps_per_job = reps_per_agent * agents_per_job
|
reps_per_job = reps_per_agent * agents_per_job
|
||||||
@ -183,84 +178,115 @@ class Slate():
|
|||||||
def _reps_for_job(self, schedC, task_id, num_conv_versions):
|
def _reps_for_job(self, schedC, task_id, num_conv_versions):
|
||||||
schedulerC = copy.deepcopy(schedC)
|
schedulerC = copy.deepcopy(schedC)
|
||||||
num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions)
|
num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions)
|
||||||
reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions)
|
reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1) * num_conv_versions)
|
||||||
if task_id == None:
|
if task_id is None:
|
||||||
return list(range(0, reps))
|
return list(range(0, reps))
|
||||||
reps_for_job = [[] for i in range(num_jobs)]
|
reps_for_job = [[] for _ in range(num_jobs)]
|
||||||
for i in range(reps):
|
for i in range(reps):
|
||||||
reps_for_job[i % num_jobs].append(i)
|
reps_for_job[i % num_jobs].append(i)
|
||||||
return reps_for_job[task_id]
|
return reps_for_job[task_id]
|
||||||
|
|
||||||
def run_local(self, filename, name, task_id, sweep_id):
|
def _make_configs_for_runs(self, config_exp_pairs):
|
||||||
|
"""
|
||||||
|
Expand configurations across all provided experiments, grid, and ablation variants.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
config_exp_pairs (list): A list of tuples where each tuple contains (filename, experiment name).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: A list of expanded configurations ready for execution.
|
||||||
|
"""
|
||||||
|
all_expanded_configs = []
|
||||||
|
|
||||||
|
for config, exp in config_exp_pairs:
|
||||||
|
config_data = self.load_config(config, exp)
|
||||||
|
grid_versions = self._make_grid_versions(config_data)
|
||||||
|
exp_variants = self._make_ablative_versions(config_data, grid_versions)
|
||||||
|
all_expanded_configs.extend(exp_variants)
|
||||||
|
|
||||||
|
return all_expanded_configs
|
||||||
|
|
||||||
|
def run_local(self, config_exp_pairs, task_id):
|
||||||
|
"""
|
||||||
|
Run all expanded configurations locally, handling all variants and their repetitions concurrently.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
config_exp_pairs (list): A list of tuples where each tuple contains (filename, experiment name).
|
||||||
|
task_id (int): The task ID for the experiments.
|
||||||
|
"""
|
||||||
self.task_id = task_id
|
self.task_id = task_id
|
||||||
config = self.load_config(filename, name)
|
all_configs = self._make_configs_for_runs(config_exp_pairs)
|
||||||
num_conv_versions = self._get_num_conv_versions(config)
|
|
||||||
schedulerC = copy.deepcopy(config.get('scheduler', {}))
|
num_conv_versions = len(all_configs)
|
||||||
|
schedulerC = copy.deepcopy(all_configs[0].get('scheduler', {}))
|
||||||
rep_ids = self._reps_for_job(schedulerC, task_id, num_conv_versions)
|
rep_ids = self._reps_for_job(schedulerC, task_id, num_conv_versions)
|
||||||
self.sweep_id = sweep_id
|
self._fork_processes(all_configs, rep_ids)
|
||||||
self._init_sweep(config)
|
|
||||||
self._fork_processes(config, rep_ids)
|
|
||||||
|
|
||||||
def run_slurm(self, filename, name):
|
def run_slurm(self, original_config_exp_string, config_exp_pairs):
|
||||||
assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
|
"""
|
||||||
config = self.load_config(filename, name)
|
Schedule all expanded configurations on SLURM within a single job.
|
||||||
slurmC = self.consume(config, 'slurm', expand=True)
|
|
||||||
schedC = self.consume(config, 'scheduler')
|
Parameters:
|
||||||
|
original_config_exp_string (str): The original string of config:experiment pairs provided by the user.
|
||||||
|
config_exp_pairs (list): A list of tuples where each tuple contains (filename, experiment name).
|
||||||
|
"""
|
||||||
|
all_configs = self._make_configs_for_runs(config_exp_pairs)
|
||||||
|
|
||||||
|
slurmC = self.consume(all_configs[0], 'slurm', expand=True)
|
||||||
s_name = self.consume(slurmC, 'name')
|
s_name = self.consume(slurmC, 'name')
|
||||||
|
|
||||||
num_conv_versions = self._get_num_conv_versions(config)
|
|
||||||
|
|
||||||
# Pre Validation
|
|
||||||
runnerName, _ = self.consume(config, 'runner'), self.consume(copy.deepcopy(config), 'wandb', {}, expand=True)
|
|
||||||
if self.consume(slurmC, 'pre_validate', True):
|
|
||||||
Runner = self.runners[runnerName]
|
|
||||||
runner = Runner(self, config)
|
|
||||||
runner.setup('PreDeployment-Validation')
|
|
||||||
|
|
||||||
self._init_sweep(config)
|
|
||||||
self.consume(config, 'wandb')
|
|
||||||
|
|
||||||
python_script = 'main.py'
|
python_script = 'main.py'
|
||||||
sh_lines = ['#!/bin/bash']
|
sh_lines = ['#!/bin/bash']
|
||||||
sh_lines += self.consume(slurmC, 'sh_lines', [])
|
sh_lines += self.consume(slurmC, 'sh_lines', [])
|
||||||
if venv := self.consume(slurmC, 'venv', False):
|
if venv := self.consume(slurmC, 'venv', False):
|
||||||
sh_lines += [f'source activate {venv}']
|
sh_lines += [f'source activate {venv}']
|
||||||
final_line = f'{python_script} {filename} {name} -t $SLURM_ARRAY_TASK_ID --sweep_id {self.sweep_id}'
|
|
||||||
|
# Use the original config:experiment string to avoid verbosity
|
||||||
|
final_line = f'{python_script} {original_config_exp_string} -t $SLURM_ARRAY_TASK_ID'
|
||||||
|
|
||||||
if self.consume(slurmC, 'python_exec', False):
|
if self.consume(slurmC, 'python_exec', False):
|
||||||
final_line = f'./omni_sif_python {final_line}'
|
final_line = f'./omni_sif_python {final_line}'
|
||||||
else:
|
else:
|
||||||
final_line = f'python3 {final_line}'
|
final_line = f'python3 {final_line}'
|
||||||
if self.consume(slurmC, 'xvfb', False):
|
if self.consume(slurmC, 'xvfb', False):
|
||||||
final_line = f'xvfb-run {final_line}'
|
final_line = f'xvfb-run {final_line}'
|
||||||
sh_lines += [final_line]
|
sh_lines.append(final_line)
|
||||||
|
|
||||||
script = "\n".join(sh_lines)
|
script = "\n".join(sh_lines)
|
||||||
|
|
||||||
num_jobs = self._calc_num_jobs(schedC, num_conv_versions)
|
num_jobs = self._calc_num_jobs(all_configs[0].get('scheduler', {}), len(all_configs))
|
||||||
|
|
||||||
last_job_idx = num_jobs - 1
|
last_job_idx = num_jobs - 1
|
||||||
num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
|
num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
|
||||||
array = f'0-{last_job_idx}%{num_parallel_jobs}'
|
array = f'0-{last_job_idx}%{num_parallel_jobs}'
|
||||||
job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
|
job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
|
||||||
|
|
||||||
self.consume(config, 'name', '')
|
if self.verify:
|
||||||
self.consume(config, 'project', '')
|
input(f'[!] Press Enter to submit the job to SLURM.')
|
||||||
self.consume(config, 'vars', '')
|
|
||||||
|
|
||||||
self.consume(config, 'grid', '')
|
|
||||||
self.consume(config, 'ablative', '')
|
|
||||||
|
|
||||||
if config != {}:
|
|
||||||
print('[!] Unconsumed Config Parts:')
|
|
||||||
pprint(config)
|
|
||||||
if self.verify or config != {}:
|
|
||||||
input(f'<Press enter to submit {num_jobs} job(s) to slurm>')
|
|
||||||
job_id = job.submit()
|
job_id = job.submit()
|
||||||
print(f'[>] Job submitted to slurm with id {job_id}')
|
print(f'[>] Job submitted to SLURM with id {job_id}')
|
||||||
with open('job_hist.log', 'a') as f:
|
|
||||||
f.write(f'{filename}:{name} submitted to slurm with ids {job_id}_0-{num_jobs} on [git:{self.get_version()}] at {datetime.datetime.now()}\n')
|
|
||||||
|
|
||||||
def _fork_processes(self, config, rep_ids):
|
# Log file entry optimization
|
||||||
schedC = self.consume(config, 'scheduler', {})
|
with open('job_hist.log', 'a') as f:
|
||||||
|
config_logs = {}
|
||||||
|
for config, exp in config_exp_pairs:
|
||||||
|
if config not in config_logs:
|
||||||
|
config_logs[config] = []
|
||||||
|
config_logs[config].append(exp)
|
||||||
|
|
||||||
|
for config, exps in config_logs.items():
|
||||||
|
exps_str = ",".join(exps)
|
||||||
|
f.write(f'{config}:{exps_str} submitted to SLURM with id {job_id}\n')
|
||||||
|
|
||||||
|
def _fork_processes(self, configs, rep_ids):
|
||||||
|
"""
|
||||||
|
Fork processes to run all expanded configurations concurrently.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
configs (list): A list of expanded configurations.
|
||||||
|
rep_ids (list): A list of repetition identifiers for the configurations.
|
||||||
|
"""
|
||||||
|
schedC = self.consume(configs[0], 'scheduler', {})
|
||||||
agents_per_job = self.consume(schedC, 'agents_per_job', 1)
|
agents_per_job = self.consume(schedC, 'agents_per_job', 1)
|
||||||
reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
|
reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
|
||||||
|
|
||||||
@ -269,45 +295,39 @@ class Slate():
|
|||||||
|
|
||||||
if num_p == 1:
|
if num_p == 1:
|
||||||
print('[i] Running within main thread')
|
print('[i] Running within main thread')
|
||||||
self._run_process(config, rep_ids=rep_ids, p_ind=0)
|
self._run_process(configs, rep_ids=rep_ids, p_ind=0)
|
||||||
return
|
return
|
||||||
|
|
||||||
procs = []
|
procs = []
|
||||||
|
|
||||||
reps_done = 0
|
reps_done = 0
|
||||||
|
|
||||||
for p in range(num_p):
|
for p in range(num_p):
|
||||||
print(f'[i] Spawning seperate thread/process ({p+1}/{num_p})')
|
print(f'[i] Spawning separate thread/process ({p+1}/{num_p})')
|
||||||
num_reps = min(node_reps - reps_done, reps_per_agent)
|
num_reps = min(node_reps - reps_done, reps_per_agent)
|
||||||
proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done+num_reps))]
|
proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done + num_reps))]
|
||||||
proc = Parallelization_Primitive(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
|
proc_configs = [configs[i % len(configs)] for i in proc_rep_ids] # Distribute configs across processes
|
||||||
|
proc = Parallelization_Primitive(target=partial(self._run_process, proc_configs, proc_rep_ids, p_ind=p))
|
||||||
proc.start()
|
proc.start()
|
||||||
procs.append(proc)
|
procs.append(proc)
|
||||||
reps_done += num_reps
|
reps_done += num_reps
|
||||||
|
|
||||||
for proc in procs:
|
for proc in procs:
|
||||||
proc.join()
|
proc.join()
|
||||||
print(f'[i] All threads/processes have terminated')
|
print('[i] All threads/processes have terminated')
|
||||||
|
|
||||||
def _init_sweep(self, config):
|
def _run_process(self, orig_configs, rep_ids, p_ind):
|
||||||
if self.sweep_id == None and self.consume(config, 'sweep.enable', False):
|
"""
|
||||||
sweepC = self.consume(config, 'sweep')
|
Run a single process for a subset of configurations.
|
||||||
wandbC = copy.deepcopy(config['wandb'])
|
|
||||||
project = self.consume(wandbC, 'project')
|
|
||||||
|
|
||||||
self.sweep_id = wandb.sweep(
|
Parameters:
|
||||||
sweep=sweepC,
|
configs (list): A list of configurations to run.
|
||||||
project=project
|
rep_ids (list): A list of repetition identifiers for the configurations.
|
||||||
)
|
p_ind (int): Process index.
|
||||||
|
"""
|
||||||
def _run_process(self, orig_config, rep_ids, p_ind):
|
for r in rep_ids:
|
||||||
config = copy.deepcopy(orig_config)
|
self.run_id = r
|
||||||
if self.consume(config, 'sweep.enable', False):
|
config = orig_configs[r % len(orig_configs)]
|
||||||
wandbC = copy.deepcopy(config['wandb'])
|
self._run_single(config, [r], p_ind=p_ind)
|
||||||
wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
|
|
||||||
else:
|
|
||||||
self.consume(config, 'sweep', {})
|
|
||||||
self._run_single(config, rep_ids, p_ind=p_ind)
|
|
||||||
|
|
||||||
def _run_single(self, orig_config, rep_ids, p_ind):
|
def _run_single(self, orig_config, rep_ids, p_ind):
|
||||||
print(f'[P{p_ind}] I will work on reps {rep_ids}')
|
print(f'[P{p_ind}] I will work on reps {rep_ids}')
|
||||||
@ -343,10 +363,10 @@ class Slate():
|
|||||||
except wandb.errors.CommError as e:
|
except wandb.errors.CommError as e:
|
||||||
retry -= 1
|
retry -= 1
|
||||||
if retry:
|
if retry:
|
||||||
print('Catched CommErr; retrying...')
|
print('Caught CommErr; retrying...')
|
||||||
time.sleep(int(60*random.random()))
|
time.sleep(int(60*random.random()))
|
||||||
else:
|
else:
|
||||||
print('Catched CommErr; not retrying')
|
print('Caught CommErr; not retrying')
|
||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
retry = 0
|
retry = 0
|
||||||
@ -359,36 +379,6 @@ class Slate():
|
|||||||
print(msg)
|
print(msg)
|
||||||
orig_config = {}
|
orig_config = {}
|
||||||
|
|
||||||
def _run_from_sweep(self, orig_config, p_ind):
|
|
||||||
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
|
|
||||||
project = self.consume(wandbC, 'project')
|
|
||||||
|
|
||||||
Runner = self.runners[runnerName]
|
|
||||||
|
|
||||||
if orig_config.consume('scheduler.bind_agent_to_core', False):
|
|
||||||
os.sched_setaffinity(0, [p_ind % os.cpu_count()])
|
|
||||||
|
|
||||||
with wandb.init(
|
|
||||||
project=project,
|
|
||||||
reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
|
|
||||||
settings=wandb.Settings(**self.consume(wandbC, 'settings', {})),
|
|
||||||
**wandbC
|
|
||||||
) as run:
|
|
||||||
config = copy.deepcopy(orig_config)
|
|
||||||
self.deep_update(config, wandb.config)
|
|
||||||
run.config = copy.deepcopy(config)
|
|
||||||
runner = Runner(self, config)
|
|
||||||
runner.setup(wandbC['group']+wandbC['job_type'])
|
|
||||||
runner.run(run)
|
|
||||||
|
|
||||||
if config != {}:
|
|
||||||
msg = ('Config was not completely consumed: ', config)
|
|
||||||
if REQUIRE_CONFIG_CONSUMED:
|
|
||||||
raise Exception(msg)
|
|
||||||
else:
|
|
||||||
print(msg)
|
|
||||||
orig_config = {}
|
|
||||||
|
|
||||||
def _make_configs_for_runs(self, config):
|
def _make_configs_for_runs(self, config):
|
||||||
c = copy.deepcopy(config)
|
c = copy.deepcopy(config)
|
||||||
|
|
||||||
@ -425,12 +415,10 @@ class Slate():
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("config_file", nargs='?', default=None)
|
parser.add_argument("config_experiments", nargs='+', help="List of config:experiment pairs")
|
||||||
parser.add_argument("experiment", nargs='?', default='DEFAULT')
|
|
||||||
parser.add_argument("-s", "--slurm", action="store_true")
|
parser.add_argument("-s", "--slurm", action="store_true")
|
||||||
parser.add_argument("-w", "--worker", action="store_true")
|
parser.add_argument("-w", "--worker", action="store_true")
|
||||||
parser.add_argument("-t", "--task_id", default=None, type=int)
|
parser.add_argument("-t", "--task_id", default=None, type=int)
|
||||||
parser.add_argument("--sweep_id", default=None, type=str)
|
|
||||||
parser.add_argument("--ask_verify", action="store_true")
|
parser.add_argument("--ask_verify", action="store_true")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -439,16 +427,21 @@ class Slate():
|
|||||||
print(f'[i] Running on version [git:{self.get_version()}]')
|
print(f'[i] Running on version [git:{self.get_version()}]')
|
||||||
|
|
||||||
if args.worker:
|
if args.worker:
|
||||||
raise Exception('Not yet implemented')
|
raise Exception('Worker mode not yet implemented')
|
||||||
|
|
||||||
assert args.config_file != None, 'Need to supply config file.'
|
config_exp_pairs = []
|
||||||
|
for config_exp in args.config_experiments:
|
||||||
|
config, exps = config_exp.split(":")
|
||||||
|
exp_list = exps.split(",")
|
||||||
|
for exp in exp_list:
|
||||||
|
config_exp_pairs.append((config, exp))
|
||||||
|
|
||||||
if args.slurm:
|
if args.slurm:
|
||||||
if args.ask_verify:
|
if args.ask_verify:
|
||||||
self.verify = True
|
self.verify = True
|
||||||
self.run_slurm(args.config_file, args.experiment)
|
self.run_slurm(args.config_experiments, config_exp_pairs)
|
||||||
else:
|
else:
|
||||||
self.run_local(args.config_file, args.experiment, args.task_id, args.sweep_id)
|
self.run_local(config_exp_pairs, args.task_id)
|
||||||
|
|
||||||
|
|
||||||
def params_combine(config: dict, key: str, iter_func):
|
def params_combine(config: dict, key: str, iter_func):
|
||||||
@ -456,21 +449,13 @@ def params_combine(config: dict, key: str, iter_func):
|
|||||||
return [config]
|
return [config]
|
||||||
|
|
||||||
combined_configs = []
|
combined_configs = []
|
||||||
# convert list/grid dictionary into flat dictionary, where the key is a tuple of the keys and the
|
|
||||||
# value is the list of values
|
|
||||||
tuple_dict = flatten_dict_to_tuple_keys(config[key])
|
tuple_dict = flatten_dict_to_tuple_keys(config[key])
|
||||||
_param_names = ['.'.join(t) for t in tuple_dict]
|
_param_names = ['.'.join(t) for t in tuple_dict]
|
||||||
|
|
||||||
param_lengths = map(len, tuple_dict.values())
|
|
||||||
|
|
||||||
# create a new config for each parameter setting
|
|
||||||
for values in iter_func(*tuple_dict.values()):
|
for values in iter_func(*tuple_dict.values()):
|
||||||
_config = copy.deepcopy(config)
|
_config = copy.deepcopy(config)
|
||||||
|
|
||||||
# Remove Grid/List Argument
|
|
||||||
del _config[key]
|
del _config[key]
|
||||||
|
|
||||||
# Expand Grid/List Parameters
|
|
||||||
for i, t in enumerate(tuple_dict.keys()):
|
for i, t in enumerate(tuple_dict.keys()):
|
||||||
insert_deep_dictionary(d=_config, t=t, value=values[i])
|
insert_deep_dictionary(d=_config, t=t, value=values[i])
|
||||||
|
|
||||||
@ -488,11 +473,7 @@ def ablative_expand(conf_list):
|
|||||||
for i, key in enumerate(tuple_dict):
|
for i, key in enumerate(tuple_dict):
|
||||||
for val in tuple_dict[key]:
|
for val in tuple_dict[key]:
|
||||||
_config = copy.deepcopy(config)
|
_config = copy.deepcopy(config)
|
||||||
|
insert_deep_dictionary(_config, key, val)
|
||||||
insert_deep_dictionary(
|
|
||||||
_config, key, val
|
|
||||||
)
|
|
||||||
|
|
||||||
_config = extend_config_name(_config, [_param_names[i]], [val])
|
_config = extend_config_name(_config, [_param_names[i]], [val])
|
||||||
combined_configs.append(_config)
|
combined_configs.append(_config)
|
||||||
return combined_configs
|
return combined_configs
|
||||||
@ -504,18 +485,16 @@ def flatten_dict_to_tuple_keys(d: MutableMapping):
|
|||||||
if isinstance(v, MutableMapping):
|
if isinstance(v, MutableMapping):
|
||||||
sub_dict = flatten_dict_to_tuple_keys(v)
|
sub_dict = flatten_dict_to_tuple_keys(v)
|
||||||
flat_dict.update({(k, *sk): sv for sk, sv in sub_dict.items()})
|
flat_dict.update({(k, *sk): sv for sk, sv in sub_dict.items()})
|
||||||
|
|
||||||
elif isinstance(v, MutableSequence):
|
elif isinstance(v, MutableSequence):
|
||||||
flat_dict[(k,)] = v
|
flat_dict[(k,)] = v
|
||||||
|
|
||||||
return flat_dict
|
return flat_dict
|
||||||
|
|
||||||
|
|
||||||
def insert_deep_dictionary(d: MutableMapping, t: tuple, value):
|
def insert_deep_dictionary(d: MutableMapping, t: tuple, value):
|
||||||
if type(t) is tuple:
|
if type(t) is tuple:
|
||||||
if len(t) == 1: # tuple contains only one key
|
if len(t) == 1:
|
||||||
d[t[0]] = value
|
d[t[0]] = value
|
||||||
else: # tuple contains more than one key
|
else:
|
||||||
if t[0] not in d:
|
if t[0] not in d:
|
||||||
d[t[0]] = dict()
|
d[t[0]] = dict()
|
||||||
insert_deep_dictionary(d[t[0]], t[1:], value)
|
insert_deep_dictionary(d[t[0]], t[1:], value)
|
||||||
@ -525,11 +504,11 @@ def insert_deep_dictionary(d: MutableMapping, t: tuple, value):
|
|||||||
|
|
||||||
def append_deep_dictionary(d: MutableMapping, t: tuple, value):
|
def append_deep_dictionary(d: MutableMapping, t: tuple, value):
|
||||||
if type(t) is tuple:
|
if type(t) is tuple:
|
||||||
if len(t) == 1: # tuple contains only one key
|
if len(t) is 1:
|
||||||
if t[0] not in d:
|
if t[0] not in d:
|
||||||
d[t[0]] = []
|
d[t[0]] = []
|
||||||
d[t[0]].append(value)
|
d[t[0]].append(value)
|
||||||
else: # tuple contains more than one key
|
else:
|
||||||
if t[0] not in d:
|
if t[0] not in d:
|
||||||
d[t[0]] = dict()
|
d[t[0]] = dict()
|
||||||
append_deep_dictionary(d[t[0]], t[1:], value)
|
append_deep_dictionary(d[t[0]], t[1:], value)
|
||||||
@ -539,15 +518,12 @@ def append_deep_dictionary(d: MutableMapping, t: tuple, value):
|
|||||||
|
|
||||||
def extend_config_name(config: dict, param_names: list, values: list) -> dict:
|
def extend_config_name(config: dict, param_names: list, values: list) -> dict:
|
||||||
_converted_name = convert_param_names(param_names, values)
|
_converted_name = convert_param_names(param_names, values)
|
||||||
|
|
||||||
config['delta_desc'] = config['delta_desc'] + '_' + _converted_name if 'delta_desc' in config else _converted_name
|
config['delta_desc'] = config['delta_desc'] + '_' + _converted_name if 'delta_desc' in config else _converted_name
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def convert_param_names(_param_names: list, values: list) -> str:
|
def convert_param_names(_param_names: list, values: list) -> str:
|
||||||
_converted_name = '_'.join("{}{}".format(
|
_converted_name = '_'.join("{}{}".format(shorten_param(k), v) for k, v in zip(_param_names, values))
|
||||||
shorten_param(k), v) for k, v in zip(_param_names, values))
|
|
||||||
# _converted_name = re.sub("[' \[\],()]", '', _converted_name)
|
|
||||||
_converted_name = re.sub("[' ]", '', _converted_name)
|
_converted_name = re.sub("[' ]", '', _converted_name)
|
||||||
_converted_name = re.sub('["]', '', _converted_name)
|
_converted_name = re.sub('["]', '', _converted_name)
|
||||||
_converted_name = re.sub("[(\[]", '_', _converted_name)
|
_converted_name = re.sub("[(\[]", '_', _converted_name)
|
||||||
|
Loading…
Reference in New Issue
Block a user