ChatGPT lied (Revert "CHatGPT tells me i need to init BEFORE I sweep")

This reverts commit 373346b589.
This commit is contained in:
Dominik Moritz Roth 2023-07-29 11:53:29 +02:00
parent 373346b589
commit 74b06d92e7

View File

@ -263,16 +263,8 @@ class Slate():
def _run_process(self, orig_config, rep_ids, p_ind):
config = copy.deepcopy(orig_config)
if self.consume(config, 'sweep.enable', False):
wandbC = self.consume(copy.deepcopy(orig_config), 'wandb', {}, expand=True)
project = self.consume(wandbC, 'project')
with wandb.init(
project=project,
reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
settings=wandb.Settings(**self.consume(wandbC, 'settings', {})),
**wandbC
) as run:
wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, run, p_ind=p_ind), count=len(rep_ids))
wandbC = copy.deepcopy(config['wandb'])
wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
else:
self.consume(config, 'sweep', {})
self._run_single(config, rep_ids, p_ind=p_ind)
@ -305,24 +297,31 @@ class Slate():
print(msg)
orig_config = config
def _run_from_sweep(self, orig_config, run, p_ind):
runnerName = self.consume(orig_config, 'runner')
def _run_from_sweep(self, orig_config, p_ind):
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
project = self.consume(wandbC, 'project')
Runner = self.runners[runnerName]
config = copy.deepcopy(orig_config)
self.deep_update(config, wandb.config)
run.config = copy.deepcopy(config)
runner = Runner(self, config)
runner.setup()
runner.run(run)
with wandb.init(
project=project,
reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
settings=wandb.Settings(**self.consume(wandbC, 'settings', {})),
**wandbC
) as run:
config = copy.deepcopy(orig_config)
self.deep_update(config, wandb.config)
run.config = copy.deepcopy(config)
runner = Runner(self, config)
runner.setup()
runner.run(run)
if config != {}:
msg = ('Config was not completely consumed: ', config)
if REQUIRE_CONFIG_CONSUMED:
raise Exception(msg)
else:
print(msg)
if config != {}:
msg = ('Config was not completely consumed: ', config)
if REQUIRE_CONFIG_CONSUMED:
raise Exception(msg)
else:
print(msg)
orig_config = config
def from_args(self):