diff --git a/slate/slate.py b/slate/slate.py index 871219e..3a0e1ae 100644 --- a/slate/slate.py +++ b/slate/slate.py @@ -174,6 +174,7 @@ class Slate(): schedulerC = copy.deepcopy(config.get('scheduler', {})) rep_ids = self._reps_for_job(schedulerC, job_id) self.sweep_id = sweep_id + self._init_sweep(config) self._fork_processes(config, rep_ids) def run_slurm(self, filename, name): @@ -184,13 +185,14 @@ class Slate(): s_name = self.consume(slurmC, 'name') # Pre Validation - runnerName, _ = self.consume(config, 'runner'), self.consume(config, 'wandb', {}, expand=True) + runnerName, _ = self.consume(config, 'runner'), self.consume(copy.deepcopy(config), 'wandb', {}, expand=True) if self.consume(slurmC, 'pre_validate', True): Runner = self.runners[runnerName] runner = Runner(self, config) runner.setup() self._init_sweep(config) + self.consume(config, 'wandb') python_script = 'main.py' sh_lines = ['#!/bin/bash'] @@ -242,7 +244,7 @@ class Slate(): print(f'[i] All threads/processes have terminated') def _init_sweep(self, config): - if self.consume(config, 'sweep.enable', False): + if self.sweep_id == None and self.consume(config, 'sweep.enable', False): sweepC = self.consume(config, 'sweep') wandbC = copy.deepcopy(config['wandb']) project = self.consume(wandbC, 'project')