From 74b06d92e795919078edb985e48a3c5d47b22398 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 29 Jul 2023 11:53:29 +0200 Subject: [PATCH] ChatGPT lied (Revert "CHatGPT tells me i need to init BEFORE I sweep") This reverts commit 373346b5894a105c0f60ca7885aab760c8736257. --- slate/slate.py | 47 +++++++++++++++++++++++------------------------ 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/slate/slate.py b/slate/slate.py index 285a94c..938f030 100644 --- a/slate/slate.py +++ b/slate/slate.py @@ -263,16 +263,8 @@ class Slate(): def _run_process(self, orig_config, rep_ids, p_ind): config = copy.deepcopy(orig_config) if self.consume(config, 'sweep.enable', False): - wandbC = self.consume(copy.deepcopy(orig_config), 'wandb', {}, expand=True) - project = self.consume(wandbC, 'project') - - with wandb.init( - project=project, - reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT), - settings=wandb.Settings(**self.consume(wandbC, 'settings', {})), - **wandbC - ) as run: - wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, run, p_ind=p_ind), count=len(rep_ids)) + wandbC = copy.deepcopy(config['wandb']) + wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids)) else: self.consume(config, 'sweep', {}) self._run_single(config, rep_ids, p_ind=p_ind) @@ -305,24 +297,31 @@ class Slate(): print(msg) orig_config = config - def _run_from_sweep(self, orig_config, run, p_ind): - runnerName = self.consume(orig_config, 'runner') + def _run_from_sweep(self, orig_config, p_ind): + runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True) + project = self.consume(wandbC, 'project') Runner = self.runners[runnerName] - config = copy.deepcopy(orig_config) - self.deep_update(config, wandb.config) - run.config = copy.deepcopy(config) - runner = Runner(self, config) - runner.setup() - runner.run(run) + with wandb.init( + project=project, + reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT), + settings=wandb.Settings(**self.consume(wandbC, 'settings', {})), + **wandbC + ) as run: + config = copy.deepcopy(orig_config) + self.deep_update(config, wandb.config) + run.config = copy.deepcopy(config) + runner = Runner(self, config) + runner.setup() + runner.run(run) - if config != {}: - msg = ('Config was not completely consumed: ', config) - if REQUIRE_CONFIG_CONSUMED: - raise Exception(msg) - else: - print(msg) + if config != {}: + msg = ('Config was not completely consumed: ', config) + if REQUIRE_CONFIG_CONSUMED: + raise Exception(msg) + else: + print(msg) orig_config = config def from_args(self):