diff --git a/slate/slate.py b/slate/slate.py index a697152..d24c987 100644 --- a/slate/slate.py +++ b/slate/slate.py @@ -16,8 +16,8 @@ import pdb d = pdb.set_trace REQUIRE_CONFIG_CONSUMED = False -WANDB_START_METHOD = 'fork' -REINIT = True +DEFAULT_START_METHOD = 'fork' +DEFAULT_REINIT = True Parallelization_Primitive = Thread # Process @@ -29,11 +29,8 @@ except ImportError: else: slurm_avaible = True -# TODO: Implement Slurm -# TODO: Implement Parallel # TODO: Implement Testing # TODO: Implement Ablative -# TODO: Implement PCA class Slate(): @@ -243,11 +240,13 @@ class Slate(): config = copy.deepcopy(orig_config) if self.consume(config, 'sweep.enable', False): sweepC = self.consume(config, 'sweep') - project = self.consume(copy.deepcopy(config['wandb']), 'project') + wandbC = copy.deepcopy(config['wandb']) + project = self.consume(wandbC, 'project') sweep_id = wandb.sweep( sweep=sweepC, project=project, - settings=wandb.Settings(start_method=WANDB_START_METHOD) + reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT), + settings=wandb.Settings(start_method=self.consume(wandbC, 'start_method', DEFAULT_START_METHOD)), ) wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids)) else: @@ -266,8 +265,8 @@ class Slate(): with wandb.init( project=project, config=copy.deepcopy(config), - reinit=REINIT, - settings=wandb.Settings(start_method=WANDB_START_METHOD), + reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT), + settings=wandb.Settings(start_method=self.consume(wandbC, 'start_method', DEFAULT_START_METHOD)), **wandbC ) as run: runner = Runner(self, config) @@ -290,8 +289,8 @@ class Slate(): with wandb.init( project=project, - reinit=REINIT, - settings=wandb.Settings(start_method=WANDB_START_METHOD), + reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT), + settings=wandb.Settings(start_method=self.consume(wandbC, 'start_method', DEFAULT_START_METHOD)), **wandbC ) as run: config = copy.deepcopy(orig_config)