diff --git a/slate/slate.py b/slate/slate.py index cd33348..ac76823 100644 --- a/slate/slate.py +++ b/slate/slate.py @@ -13,6 +13,7 @@ import pdb d = pdb.set_trace REQUIRE_CONFIG_CONSUMED = False +WANDB_START_METHOD = 'process' Parallelization_Primitive = Thread # Process @@ -216,7 +217,7 @@ class Slate(): sweep_id = wandb.sweep( sweep=sweepC, project=project, - settings=wandb.Settings(start_method="thread") + settings=wandb.Settings(start_method=WANDB_START_METHOD) ) wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids)) else: @@ -235,7 +236,7 @@ class Slate(): with wandb.init( project=project, config=copy.deepcopy(config), - settings=wandb.Settings(start_method="thread"), + settings=wandb.Settings(start_method=WANDB_START_METHOD), **wandbC ) as run: runner(self, run, config) @@ -256,7 +257,7 @@ class Slate(): with wandb.init( project=project, - settings=wandb.Settings(start_method="thread"), + settings=wandb.Settings(start_method=WANDB_START_METHOD), **wandbC ) as run: config = copy.deepcopy(orig_config)