Allow configuring reps per version
This commit is contained in:
parent
94408f6b08
commit
9260ea01f3
@ -38,6 +38,7 @@ else:
|
|||||||
class Slate():
|
class Slate():
|
||||||
def __init__(self, runners):
|
def __init__(self, runners):
|
||||||
self.runners = {
|
self.runners = {
|
||||||
|
'void': Void_Runner,
|
||||||
'printConfig': Print_Config_Runner,
|
'printConfig': Print_Config_Runner,
|
||||||
'pdb': PDB_Runner,
|
'pdb': PDB_Runner,
|
||||||
}
|
}
|
||||||
@ -164,30 +165,31 @@ class Slate():
|
|||||||
self._version = sha
|
self._version = sha
|
||||||
return self._version
|
return self._version
|
||||||
|
|
||||||
def _calc_num_jobs(self, schedC):
|
def _calc_num_jobs(self, schedC, num_conv_versions):
|
||||||
schedulerC = copy.deepcopy(schedC)
|
schedulerC = copy.deepcopy(schedC)
|
||||||
reps = self.consume(schedulerC, 'repetitions', 1)
|
reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions)
|
||||||
agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
|
agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
|
||||||
reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
|
reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
|
||||||
reps_per_job = reps_per_agent * agents_per_job
|
reps_per_job = reps_per_agent * agents_per_job
|
||||||
jobs_needed = math.ceil(reps / reps_per_job)
|
jobs_needed = math.ceil(reps / reps_per_job)
|
||||||
return jobs_needed
|
return jobs_needed
|
||||||
|
|
||||||
def _reps_for_job(self, schedC, job_id):
|
def _reps_for_job(self, schedC, job_id, num_conv_versions):
|
||||||
schedulerC = copy.deepcopy(schedC)
|
schedulerC = copy.deepcopy(schedC)
|
||||||
num_jobs = self._calc_num_jobs(schedulerC)
|
num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions)
|
||||||
reps = self.consume(schedulerC, 'repetitions', 1)
|
reps = self.consume(schedulerC, 'repetitions', 1)
|
||||||
if job_id == None:
|
if job_id == None:
|
||||||
return list(range(0, reps))
|
return list(range(0, reps))
|
||||||
reps_for_job = [[]] * num_jobs
|
reps_for_job = [[]] * num_jobs
|
||||||
for i in range(reps):
|
for i in range(reps):
|
||||||
reps_for_job[i % num_jobs].append(i)
|
reps_for_job[i % num_jobs].append(i)
|
||||||
return reps_for_job[job_id-1]
|
return reps_for_job[job_id]
|
||||||
|
|
||||||
def run_local(self, filename, name, job_id, sweep_id):
|
def run_local(self, filename, name, job_id, sweep_id):
|
||||||
config = self.load_config(filename, name)
|
config = self.load_config(filename, name)
|
||||||
|
num_conv_versions = self._get_num_conv_versions(config)
|
||||||
schedulerC = copy.deepcopy(config.get('scheduler', {}))
|
schedulerC = copy.deepcopy(config.get('scheduler', {}))
|
||||||
rep_ids = self._reps_for_job(schedulerC, job_id)
|
rep_ids = self._reps_for_job(schedulerC, job_id, num_conv_versions)
|
||||||
self.sweep_id = sweep_id
|
self.sweep_id = sweep_id
|
||||||
self._init_sweep(config)
|
self._init_sweep(config)
|
||||||
self._fork_processes(config, rep_ids)
|
self._fork_processes(config, rep_ids)
|
||||||
@ -199,6 +201,8 @@ class Slate():
|
|||||||
schedC = self.consume(config, 'scheduler')
|
schedC = self.consume(config, 'scheduler')
|
||||||
s_name = self.consume(slurmC, 'name')
|
s_name = self.consume(slurmC, 'name')
|
||||||
|
|
||||||
|
num_conv_versions = self._get_num_conv_versions(config)
|
||||||
|
|
||||||
# Pre Validation
|
# Pre Validation
|
||||||
runnerName, _ = self.consume(config, 'runner'), self.consume(copy.deepcopy(config), 'wandb', {}, expand=True)
|
runnerName, _ = self.consume(config, 'runner'), self.consume(copy.deepcopy(config), 'wandb', {}, expand=True)
|
||||||
if self.consume(slurmC, 'pre_validate', True):
|
if self.consume(slurmC, 'pre_validate', True):
|
||||||
@ -217,7 +221,7 @@ class Slate():
|
|||||||
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID --sweep_id {self.sweep_id}']
|
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID --sweep_id {self.sweep_id}']
|
||||||
script = "\n".join(sh_lines)
|
script = "\n".join(sh_lines)
|
||||||
|
|
||||||
num_jobs = self._calc_num_jobs(schedC)
|
num_jobs = self._calc_num_jobs(schedC, num_conv_versions)
|
||||||
|
|
||||||
last_job_idx = num_jobs - 1
|
last_job_idx = num_jobs - 1
|
||||||
num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
|
num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
|
||||||
@ -335,12 +339,20 @@ class Slate():
|
|||||||
print(msg)
|
print(msg)
|
||||||
orig_config = {}
|
orig_config = {}
|
||||||
|
|
||||||
def _make_config_for_run(self, config, r):
|
def _make_configs_for_runs(self, config):
|
||||||
c = copy.deepcopy(config)
|
c = copy.deepcopy(config)
|
||||||
|
|
||||||
grid_versions = self._make_grid_versions(c)
|
grid_versions = self._make_grid_versions(c)
|
||||||
all_versions = self._make_ablative_versions(c, grid_versions)
|
all_versions = self._make_ablative_versions(c, grid_versions)
|
||||||
|
|
||||||
|
return all_versions
|
||||||
|
|
||||||
|
def _get_num_conv_versions(self, config):
|
||||||
|
return len(self._make_configs_for_runs(config))
|
||||||
|
|
||||||
|
def _make_config_for_run(self, config, r):
|
||||||
|
all_versions = self._make_configs_for_runs(config)
|
||||||
|
|
||||||
i = r % len(all_versions)
|
i = r % len(all_versions)
|
||||||
print(f'[d] Running version {i}/{len(all_versions)} in run {r}')
|
print(f'[d] Running version {i}/{len(all_versions)} in run {r}')
|
||||||
cur_conf = all_versions[i]
|
cur_conf = all_versions[i]
|
||||||
@ -523,6 +535,13 @@ class Print_Config_Runner(Slate_Runner):
|
|||||||
del config[k]
|
del config[k]
|
||||||
|
|
||||||
|
|
||||||
|
class Void_Runner(Slate_Runner):
|
||||||
|
def run(self, run):
|
||||||
|
slate, config = self.slate, self.config
|
||||||
|
for k in list(config.keys()):
|
||||||
|
del config[k]
|
||||||
|
|
||||||
|
|
||||||
class PDB_Runner(Slate_Runner):
|
class PDB_Runner(Slate_Runner):
|
||||||
def run(self, run):
|
def run(self, run):
|
||||||
d()
|
d()
|
||||||
|
Loading…
Reference in New Issue
Block a user