Allow configuring reps per version
This commit is contained in:
parent
94408f6b08
commit
9260ea01f3
@ -38,6 +38,7 @@ else:
|
||||
class Slate():
|
||||
def __init__(self, runners):
|
||||
self.runners = {
|
||||
'void': Void_Runner,
|
||||
'printConfig': Print_Config_Runner,
|
||||
'pdb': PDB_Runner,
|
||||
}
|
||||
@ -164,30 +165,31 @@ class Slate():
|
||||
self._version = sha
|
||||
return self._version
|
||||
|
||||
def _calc_num_jobs(self, schedC):
|
||||
def _calc_num_jobs(self, schedC, num_conv_versions):
|
||||
schedulerC = copy.deepcopy(schedC)
|
||||
reps = self.consume(schedulerC, 'repetitions', 1)
|
||||
reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions)
|
||||
agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
|
||||
reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
|
||||
reps_per_job = reps_per_agent * agents_per_job
|
||||
jobs_needed = math.ceil(reps / reps_per_job)
|
||||
return jobs_needed
|
||||
|
||||
def _reps_for_job(self, schedC, job_id):
|
||||
def _reps_for_job(self, schedC, job_id, num_conv_versions):
|
||||
schedulerC = copy.deepcopy(schedC)
|
||||
num_jobs = self._calc_num_jobs(schedulerC)
|
||||
num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions)
|
||||
reps = self.consume(schedulerC, 'repetitions', 1)
|
||||
if job_id == None:
|
||||
return list(range(0, reps))
|
||||
reps_for_job = [[]] * num_jobs
|
||||
for i in range(reps):
|
||||
reps_for_job[i % num_jobs].append(i)
|
||||
return reps_for_job[job_id-1]
|
||||
return reps_for_job[job_id]
|
||||
|
||||
def run_local(self, filename, name, job_id, sweep_id):
|
||||
config = self.load_config(filename, name)
|
||||
num_conv_versions = self._get_num_conv_versions(config)
|
||||
schedulerC = copy.deepcopy(config.get('scheduler', {}))
|
||||
rep_ids = self._reps_for_job(schedulerC, job_id)
|
||||
rep_ids = self._reps_for_job(schedulerC, job_id, num_conv_versions)
|
||||
self.sweep_id = sweep_id
|
||||
self._init_sweep(config)
|
||||
self._fork_processes(config, rep_ids)
|
||||
@ -199,6 +201,8 @@ class Slate():
|
||||
schedC = self.consume(config, 'scheduler')
|
||||
s_name = self.consume(slurmC, 'name')
|
||||
|
||||
num_conv_versions = self._get_num_conv_versions(config)
|
||||
|
||||
# Pre Validation
|
||||
runnerName, _ = self.consume(config, 'runner'), self.consume(copy.deepcopy(config), 'wandb', {}, expand=True)
|
||||
if self.consume(slurmC, 'pre_validate', True):
|
||||
@ -217,7 +221,7 @@ class Slate():
|
||||
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID --sweep_id {self.sweep_id}']
|
||||
script = "\n".join(sh_lines)
|
||||
|
||||
num_jobs = self._calc_num_jobs(schedC)
|
||||
num_jobs = self._calc_num_jobs(schedC, num_conv_versions)
|
||||
|
||||
last_job_idx = num_jobs - 1
|
||||
num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
|
||||
@ -335,12 +339,20 @@ class Slate():
|
||||
print(msg)
|
||||
orig_config = {}
|
||||
|
||||
def _make_config_for_run(self, config, r):
|
||||
def _make_configs_for_runs(self, config):
|
||||
c = copy.deepcopy(config)
|
||||
|
||||
grid_versions = self._make_grid_versions(c)
|
||||
all_versions = self._make_ablative_versions(c, grid_versions)
|
||||
|
||||
return all_versions
|
||||
|
||||
def _get_num_conv_versions(self, config):
|
||||
return len(self._make_configs_for_runs(config))
|
||||
|
||||
def _make_config_for_run(self, config, r):
|
||||
all_versions = self._make_configs_for_runs(config)
|
||||
|
||||
i = r % len(all_versions)
|
||||
print(f'[d] Running version {i}/{len(all_versions)} in run {r}')
|
||||
cur_conf = all_versions[i]
|
||||
@ -523,6 +535,13 @@ class Print_Config_Runner(Slate_Runner):
|
||||
del config[k]
|
||||
|
||||
|
||||
class Void_Runner(Slate_Runner):
|
||||
def run(self, run):
|
||||
slate, config = self.slate, self.config
|
||||
for k in list(config.keys()):
|
||||
del config[k]
|
||||
|
||||
|
||||
class PDB_Runner(Slate_Runner):
|
||||
def run(self, run):
|
||||
d()
|
||||
|
Loading…
Reference in New Issue
Block a user