Allow configuring reps per version

This commit is contained in:
Dominik Moritz Roth 2023-07-29 14:28:23 +02:00
parent 94408f6b08
commit 9260ea01f3

View File

@ -38,6 +38,7 @@ else:
class Slate(): class Slate():
def __init__(self, runners): def __init__(self, runners):
self.runners = { self.runners = {
'void': Void_Runner,
'printConfig': Print_Config_Runner, 'printConfig': Print_Config_Runner,
'pdb': PDB_Runner, 'pdb': PDB_Runner,
} }
@ -164,30 +165,31 @@ class Slate():
self._version = sha self._version = sha
return self._version return self._version
def _calc_num_jobs(self, schedC): def _calc_num_jobs(self, schedC, num_conv_versions):
schedulerC = copy.deepcopy(schedC) schedulerC = copy.deepcopy(schedC)
reps = self.consume(schedulerC, 'repetitions', 1) reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions)
agents_per_job = self.consume(schedulerC, 'agents_per_job', 1) agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1) reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
reps_per_job = reps_per_agent * agents_per_job reps_per_job = reps_per_agent * agents_per_job
jobs_needed = math.ceil(reps / reps_per_job) jobs_needed = math.ceil(reps / reps_per_job)
return jobs_needed return jobs_needed
def _reps_for_job(self, schedC, job_id): def _reps_for_job(self, schedC, job_id, num_conv_versions):
schedulerC = copy.deepcopy(schedC) schedulerC = copy.deepcopy(schedC)
num_jobs = self._calc_num_jobs(schedulerC) num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions)
reps = self.consume(schedulerC, 'repetitions', 1) reps = self.consume(schedulerC, 'repetitions', 1)
if job_id == None: if job_id == None:
return list(range(0, reps)) return list(range(0, reps))
reps_for_job = [[]] * num_jobs reps_for_job = [[]] * num_jobs
for i in range(reps): for i in range(reps):
reps_for_job[i % num_jobs].append(i) reps_for_job[i % num_jobs].append(i)
return reps_for_job[job_id-1] return reps_for_job[job_id]
def run_local(self, filename, name, job_id, sweep_id): def run_local(self, filename, name, job_id, sweep_id):
config = self.load_config(filename, name) config = self.load_config(filename, name)
num_conv_versions = self._get_num_conv_versions(config)
schedulerC = copy.deepcopy(config.get('scheduler', {})) schedulerC = copy.deepcopy(config.get('scheduler', {}))
rep_ids = self._reps_for_job(schedulerC, job_id) rep_ids = self._reps_for_job(schedulerC, job_id, num_conv_versions)
self.sweep_id = sweep_id self.sweep_id = sweep_id
self._init_sweep(config) self._init_sweep(config)
self._fork_processes(config, rep_ids) self._fork_processes(config, rep_ids)
@ -199,6 +201,8 @@ class Slate():
schedC = self.consume(config, 'scheduler') schedC = self.consume(config, 'scheduler')
s_name = self.consume(slurmC, 'name') s_name = self.consume(slurmC, 'name')
num_conv_versions = self._get_num_conv_versions(config)
# Pre Validation # Pre Validation
runnerName, _ = self.consume(config, 'runner'), self.consume(copy.deepcopy(config), 'wandb', {}, expand=True) runnerName, _ = self.consume(config, 'runner'), self.consume(copy.deepcopy(config), 'wandb', {}, expand=True)
if self.consume(slurmC, 'pre_validate', True): if self.consume(slurmC, 'pre_validate', True):
@ -217,7 +221,7 @@ class Slate():
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID --sweep_id {self.sweep_id}'] sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID --sweep_id {self.sweep_id}']
script = "\n".join(sh_lines) script = "\n".join(sh_lines)
num_jobs = self._calc_num_jobs(schedC) num_jobs = self._calc_num_jobs(schedC, num_conv_versions)
last_job_idx = num_jobs - 1 last_job_idx = num_jobs - 1
num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs) num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
@ -335,12 +339,20 @@ class Slate():
print(msg) print(msg)
orig_config = {} orig_config = {}
def _make_config_for_run(self, config, r): def _make_configs_for_runs(self, config):
c = copy.deepcopy(config) c = copy.deepcopy(config)
grid_versions = self._make_grid_versions(c) grid_versions = self._make_grid_versions(c)
all_versions = self._make_ablative_versions(c, grid_versions) all_versions = self._make_ablative_versions(c, grid_versions)
return all_versions
def _get_num_conv_versions(self, config):
return len(self._make_configs_for_runs(config))
def _make_config_for_run(self, config, r):
all_versions = self._make_configs_for_runs(config)
i = r % len(all_versions) i = r % len(all_versions)
print(f'[d] Running version {i}/{len(all_versions)} in run {r}') print(f'[d] Running version {i}/{len(all_versions)} in run {r}')
cur_conf = all_versions[i] cur_conf = all_versions[i]
@ -523,6 +535,13 @@ class Print_Config_Runner(Slate_Runner):
del config[k] del config[k]
class Void_Runner(Slate_Runner):
def run(self, run):
slate, config = self.slate, self.config
for k in list(config.keys()):
del config[k]
class PDB_Runner(Slate_Runner): class PDB_Runner(Slate_Runner):
def run(self, run): def run(self, run):
d() d()