diff --git a/slate/slate.py b/slate/slate.py index 938f030..285a94c 100644 --- a/slate/slate.py +++ b/slate/slate.py @@ -263,8 +263,16 @@ class Slate(): def _run_process(self, orig_config, rep_ids, p_ind): config = copy.deepcopy(orig_config) if self.consume(config, 'sweep.enable', False): - wandbC = copy.deepcopy(config['wandb']) - wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids)) + wandbC = self.consume(copy.deepcopy(orig_config), 'wandb', {}, expand=True) + project = self.consume(wandbC, 'project') + + with wandb.init( + project=project, + reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT), + settings=wandb.Settings(**self.consume(wandbC, 'settings', {})), + **wandbC + ) as run: + wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, run, p_ind=p_ind), count=len(rep_ids)) else: self.consume(config, 'sweep', {}) self._run_single(config, rep_ids, p_ind=p_ind) @@ -297,31 +305,24 @@ class Slate(): print(msg) orig_config = config - def _run_from_sweep(self, orig_config, p_ind): - runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True) - project = self.consume(wandbC, 'project') + def _run_from_sweep(self, orig_config, run, p_ind): + runnerName = self.consume(orig_config, 'runner') Runner = self.runners[runnerName] - with wandb.init( - project=project, - reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT), - settings=wandb.Settings(**self.consume(wandbC, 'settings', {})), - **wandbC - ) as run: - config = copy.deepcopy(orig_config) - self.deep_update(config, wandb.config) - run.config = copy.deepcopy(config) - runner = Runner(self, config) - runner.setup() - runner.run(run) + config = copy.deepcopy(orig_config) + self.deep_update(config, wandb.config) + run.config = copy.deepcopy(config) + runner = Runner(self, config) + runner.setup() + runner.run(run) - if config != {}: - msg = ('Config was not completely consumed: ', config) - if REQUIRE_CONFIG_CONSUMED: - raise Exception(msg) - else: - print(msg) + if config != {}: + msg = ('Config was not completely consumed: ', config) + if REQUIRE_CONFIG_CONSUMED: + raise Exception(msg) + else: + print(msg) orig_config = config def from_args(self):