Make more wandb options configurable

This commit is contained in:
Dominik Moritz Roth 2023-07-27 11:29:06 +02:00
parent eddc5e6092
commit 72d28de17b

View File

@ -16,8 +16,8 @@ import pdb
d = pdb.set_trace d = pdb.set_trace
REQUIRE_CONFIG_CONSUMED = False REQUIRE_CONFIG_CONSUMED = False
WANDB_START_METHOD = 'fork' DEFAULT_START_METHOD = 'fork'
REINIT = True DEFAULT_REINIT = True
Parallelization_Primitive = Thread # Process Parallelization_Primitive = Thread # Process
@ -29,11 +29,8 @@ except ImportError:
else: else:
slurm_avaible = True slurm_avaible = True
# TODO: Implement Slurm
# TODO: Implement Parallel
# TODO: Implement Testing # TODO: Implement Testing
# TODO: Implement Ablative # TODO: Implement Ablative
# TODO: Implement PCA
class Slate(): class Slate():
@ -243,11 +240,13 @@ class Slate():
config = copy.deepcopy(orig_config) config = copy.deepcopy(orig_config)
if self.consume(config, 'sweep.enable', False): if self.consume(config, 'sweep.enable', False):
sweepC = self.consume(config, 'sweep') sweepC = self.consume(config, 'sweep')
project = self.consume(copy.deepcopy(config['wandb']), 'project') wandbC = copy.deepcopy(config['wandb'])
project = self.consume(wandbC, 'project')
sweep_id = wandb.sweep( sweep_id = wandb.sweep(
sweep=sweepC, sweep=sweepC,
project=project, project=project,
settings=wandb.Settings(start_method=WANDB_START_METHOD) reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
settings=wandb.Settings(start_method=self.consume(wandbC, 'start_method', DEFAULT_START_METHOD)),
) )
wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids)) wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
else: else:
@ -266,8 +265,8 @@ class Slate():
with wandb.init( with wandb.init(
project=project, project=project,
config=copy.deepcopy(config), config=copy.deepcopy(config),
reinit=REINIT, reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
settings=wandb.Settings(start_method=WANDB_START_METHOD), settings=wandb.Settings(start_method=self.consume(wandbC, 'start_method', DEFAULT_START_METHOD)),
**wandbC **wandbC
) as run: ) as run:
runner = Runner(self, config) runner = Runner(self, config)
@ -290,8 +289,8 @@ class Slate():
with wandb.init( with wandb.init(
project=project, project=project,
reinit=REINIT, reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
settings=wandb.Settings(start_method=WANDB_START_METHOD), settings=wandb.Settings(start_method=self.consume(wandbC, 'start_method', DEFAULT_START_METHOD)),
**wandbC **wandbC
) as run: ) as run:
config = copy.deepcopy(orig_config) config = copy.deepcopy(orig_config)