diff --git a/slate/slate.py b/slate/slate.py index 6234be0..7b5d550 100644 --- a/slate/slate.py +++ b/slate/slate.py @@ -38,6 +38,7 @@ else: class Slate(): def __init__(self, runners): self.runners = { + 'void': Void_Runner, 'printConfig': Print_Config_Runner, 'pdb': PDB_Runner, } @@ -164,30 +165,31 @@ class Slate(): self._version = sha return self._version - def _calc_num_jobs(self, schedC): + def _calc_num_jobs(self, schedC, num_conv_versions): schedulerC = copy.deepcopy(schedC) - reps = self.consume(schedulerC, 'repetitions', 1) + reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions) agents_per_job = self.consume(schedulerC, 'agents_per_job', 1) reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1) reps_per_job = reps_per_agent * agents_per_job jobs_needed = math.ceil(reps / reps_per_job) return jobs_needed - def _reps_for_job(self, schedC, job_id): + def _reps_for_job(self, schedC, job_id, num_conv_versions): schedulerC = copy.deepcopy(schedC) - num_jobs = self._calc_num_jobs(schedulerC) + num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions) reps = self.consume(schedulerC, 'repetitions', 1) if job_id == None: return list(range(0, reps)) reps_for_job = [[]] * num_jobs for i in range(reps): reps_for_job[i % num_jobs].append(i) - return reps_for_job[job_id-1] + return reps_for_job[job_id] def run_local(self, filename, name, job_id, sweep_id): config = self.load_config(filename, name) + num_conv_versions = self._get_num_conv_versions(config) schedulerC = copy.deepcopy(config.get('scheduler', {})) - rep_ids = self._reps_for_job(schedulerC, job_id) + rep_ids = self._reps_for_job(schedulerC, job_id, num_conv_versions) self.sweep_id = sweep_id self._init_sweep(config) self._fork_processes(config, rep_ids) @@ -199,6 +201,8 @@ class Slate(): schedC = self.consume(config, 'scheduler') s_name = self.consume(slurmC, 'name') + num_conv_versions = self._get_num_conv_versions(config) + # Pre Validation runnerName, _ = self.consume(config, 'runner'), self.consume(copy.deepcopy(config), 'wandb', {}, expand=True) if self.consume(slurmC, 'pre_validate', True): @@ -217,7 +221,7 @@ class Slate(): sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID --sweep_id {self.sweep_id}'] script = "\n".join(sh_lines) - num_jobs = self._calc_num_jobs(schedC) + num_jobs = self._calc_num_jobs(schedC, num_conv_versions) last_job_idx = num_jobs - 1 num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs) @@ -335,12 +339,20 @@ class Slate(): print(msg) orig_config = {} - def _make_config_for_run(self, config, r): + def _make_configs_for_runs(self, config): c = copy.deepcopy(config) grid_versions = self._make_grid_versions(c) all_versions = self._make_ablative_versions(c, grid_versions) + return all_versions + + def _get_num_conv_versions(self, config): + return len(self._make_configs_for_runs(config)) + + def _make_config_for_run(self, config, r): + all_versions = self._make_configs_for_runs(config) + i = r % len(all_versions) print(f'[d] Running version {i}/{len(all_versions)} in run {r}') cur_conf = all_versions[i] @@ -523,6 +535,13 @@ class Print_Config_Runner(Slate_Runner): del config[k] +class Void_Runner(Slate_Runner): + def run(self, run): + slate, config = self.slate, self.config + for k in list(config.keys()): + del config[k] + + class PDB_Runner(Slate_Runner): def run(self, run): d()