Allow configuring reps per version

This commit is contained in:
Dominik Moritz Roth 2023-07-29 14:28:23 +02:00
parent 94408f6b08
commit 9260ea01f3

View File

@ -38,6 +38,7 @@ else:
class Slate():
def __init__(self, runners):
self.runners = {
'void': Void_Runner,
'printConfig': Print_Config_Runner,
'pdb': PDB_Runner,
}
@ -164,30 +165,31 @@ class Slate():
self._version = sha
return self._version
def _calc_num_jobs(self, schedC):
def _calc_num_jobs(self, schedC, num_conv_versions):
schedulerC = copy.deepcopy(schedC)
reps = self.consume(schedulerC, 'repetitions', 1)
reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions)
agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
reps_per_job = reps_per_agent * agents_per_job
jobs_needed = math.ceil(reps / reps_per_job)
return jobs_needed
def _reps_for_job(self, schedC, job_id):
def _reps_for_job(self, schedC, job_id, num_conv_versions):
schedulerC = copy.deepcopy(schedC)
num_jobs = self._calc_num_jobs(schedulerC)
num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions)
reps = self.consume(schedulerC, 'repetitions', 1)
if job_id == None:
return list(range(0, reps))
reps_for_job = [[]] * num_jobs
for i in range(reps):
reps_for_job[i % num_jobs].append(i)
return reps_for_job[job_id-1]
return reps_for_job[job_id]
def run_local(self, filename, name, job_id, sweep_id):
config = self.load_config(filename, name)
num_conv_versions = self._get_num_conv_versions(config)
schedulerC = copy.deepcopy(config.get('scheduler', {}))
rep_ids = self._reps_for_job(schedulerC, job_id)
rep_ids = self._reps_for_job(schedulerC, job_id, num_conv_versions)
self.sweep_id = sweep_id
self._init_sweep(config)
self._fork_processes(config, rep_ids)
@ -199,6 +201,8 @@ class Slate():
schedC = self.consume(config, 'scheduler')
s_name = self.consume(slurmC, 'name')
num_conv_versions = self._get_num_conv_versions(config)
# Pre Validation
runnerName, _ = self.consume(config, 'runner'), self.consume(copy.deepcopy(config), 'wandb', {}, expand=True)
if self.consume(slurmC, 'pre_validate', True):
@ -217,7 +221,7 @@ class Slate():
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID --sweep_id {self.sweep_id}']
script = "\n".join(sh_lines)
num_jobs = self._calc_num_jobs(schedC)
num_jobs = self._calc_num_jobs(schedC, num_conv_versions)
last_job_idx = num_jobs - 1
num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
@ -335,12 +339,20 @@ class Slate():
print(msg)
orig_config = {}
def _make_config_for_run(self, config, r):
def _make_configs_for_runs(self, config):
c = copy.deepcopy(config)
grid_versions = self._make_grid_versions(c)
all_versions = self._make_ablative_versions(c, grid_versions)
return all_versions
def _get_num_conv_versions(self, config):
return len(self._make_configs_for_runs(config))
def _make_config_for_run(self, config, r):
all_versions = self._make_configs_for_runs(config)
i = r % len(all_versions)
print(f'[d] Running version {i}/{len(all_versions)} in run {r}')
cur_conf = all_versions[i]
@ -523,6 +535,13 @@ class Print_Config_Runner(Slate_Runner):
del config[k]
class Void_Runner(Slate_Runner):
    """Runner that does no work; ``run`` only empties ``self.config``."""

    def run(self, run):
        """Delete every key from ``self.config``; ``run`` itself is unused.

        NOTE(review): presumably the keys are removed so the config counts
        as fully consumed -- confirm against the caller.
        """
        slate = self.slate
        config = self.config
        # Snapshot the keys first; deleting while iterating the live view
        # would fail.
        for key in tuple(config.keys()):
            del config[key]
class PDB_Runner(Slate_Runner):
def run(self, run):
d()