From cb4537e5b93b62629e7cb4b755b80e3a8e7fc6e6 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Thu, 27 Jul 2023 12:34:36 +0200 Subject: [PATCH] Only init sweep once (on login-node for slurm) --- slate/slate.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/slate/slate.py b/slate/slate.py index f571257..3f73924 100644 --- a/slate/slate.py +++ b/slate/slate.py @@ -169,10 +169,11 @@ class Slate(): reps_for_job[i % num_jobs].append(i) return reps_for_job[job_id-1] - def run_local(self, filename, name, job_id): + def run_local(self, filename, name, job_id, sweep_id): config = self.load_config(filename, name) schedulerC = copy.deepcopy(config.get('scheduler', {})) rep_ids = self._reps_for_job(schedulerC, job_id) + self.sweep_id = sweep_id self._fork_processes(config, rep_ids) def run_slurm(self, filename, name): @@ -189,12 +190,14 @@ class Slate(): runner = Runner(self, config) runner.setup() + self._init_sweep(config) + python_script = 'main.py' sh_lines = ['#!/bin/bash'] sh_lines += self.consume(slurmC, 'sh_lines', []) if venv := self.consume(slurmC, 'venv', False): sh_lines += [f'source activate {venv}'] - sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID'] + sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID -s {self.sweep_id}'] script = "\n".join(sh_lines) num_jobs = self._calc_num_jobs(schedC) @@ -218,7 +221,7 @@ class Slate(): if num_p == 1: print('[i] Running within main thread') - self._run_single(config, rep_ids=rep_ids, p_ind=0) + self._run_process(config, rep_ids=rep_ids, p_ind=0) return procs = [] @@ -238,19 +241,24 @@ class Slate(): proc.join() print(f'[i] All threads/processes have terminated') - def _run_process(self, orig_config, rep_ids, p_ind): - config = copy.deepcopy(orig_config) + def _init_sweep(self, config): if self.consume(config, 'sweep.enable', False): sweepC = self.consume(config, 'sweep') wandbC = copy.deepcopy(config['wandb']) project = self.consume(wandbC, 'project') - sweep_id = wandb.sweep( + + self.sweep_id = wandb.sweep( sweep=sweepC, project=project, reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT), settings=wandb.Settings(start_method=self.consume(wandbC, 'start_method', DEFAULT_START_METHOD)), ) - wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids)) + + def _run_process(self, orig_config, rep_ids, p_ind): + config = copy.deepcopy(orig_config) + if self.consume(config, 'sweep.enable', False): + wandbC = copy.deepcopy(config['wandb']) + wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids)) else: self.consume(config, 'sweep', {}) self._run_single(config, rep_ids, p_ind=p_ind) @@ -319,6 +327,7 @@ class Slate(): parser.add_argument("-s", "--slurm", action="store_true") parser.add_argument("-w", "--worker", action="store_true") parser.add_argument("-j", "--job_id", default=None, type=int) + parser.add_argument("-s", "--sweep_id", default=None, type=str) args = parser.parse_args() @@ -332,7 +341,7 @@ class Slate(): if args.slurm: self.run_slurm(args.config_file, args.experiment) else: - self.run_local(args.config_file, args.experiment, args.job_id) + self.run_local(args.config_file, args.experiment, args.job_id, args.sweep_id) class Slate_Runner():