Initialize the sweep only once (on the login node when using Slurm)

This commit is contained in:
Dominik Moritz Roth 2023-07-27 12:34:36 +02:00
parent 966dfbcfb8
commit cb4537e5b9

View File

@ -169,10 +169,11 @@ class Slate():
reps_for_job[i % num_jobs].append(i) reps_for_job[i % num_jobs].append(i)
return reps_for_job[job_id-1] return reps_for_job[job_id-1]
def run_local(self, filename, name, job_id): def run_local(self, filename, name, job_id, sweep_id):
config = self.load_config(filename, name) config = self.load_config(filename, name)
schedulerC = copy.deepcopy(config.get('scheduler', {})) schedulerC = copy.deepcopy(config.get('scheduler', {}))
rep_ids = self._reps_for_job(schedulerC, job_id) rep_ids = self._reps_for_job(schedulerC, job_id)
self.sweep_id = sweep_id
self._fork_processes(config, rep_ids) self._fork_processes(config, rep_ids)
def run_slurm(self, filename, name): def run_slurm(self, filename, name):
@ -189,12 +190,14 @@ class Slate():
runner = Runner(self, config) runner = Runner(self, config)
runner.setup() runner.setup()
self._init_sweep(config)
python_script = 'main.py' python_script = 'main.py'
sh_lines = ['#!/bin/bash'] sh_lines = ['#!/bin/bash']
sh_lines += self.consume(slurmC, 'sh_lines', []) sh_lines += self.consume(slurmC, 'sh_lines', [])
if venv := self.consume(slurmC, 'venv', False): if venv := self.consume(slurmC, 'venv', False):
sh_lines += [f'source activate {venv}'] sh_lines += [f'source activate {venv}']
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID'] sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID -s {self.sweep_id}']
script = "\n".join(sh_lines) script = "\n".join(sh_lines)
num_jobs = self._calc_num_jobs(schedC) num_jobs = self._calc_num_jobs(schedC)
@ -218,7 +221,7 @@ class Slate():
if num_p == 1: if num_p == 1:
print('[i] Running within main thread') print('[i] Running within main thread')
self._run_single(config, rep_ids=rep_ids, p_ind=0) self._run_process(config, rep_ids=rep_ids, p_ind=0)
return return
procs = [] procs = []
@ -238,19 +241,24 @@ class Slate():
proc.join() proc.join()
print(f'[i] All threads/processes have terminated') print(f'[i] All threads/processes have terminated')
def _run_process(self, orig_config, rep_ids, p_ind): def _init_sweep(self, config):
config = copy.deepcopy(orig_config)
if self.consume(config, 'sweep.enable', False): if self.consume(config, 'sweep.enable', False):
sweepC = self.consume(config, 'sweep') sweepC = self.consume(config, 'sweep')
wandbC = copy.deepcopy(config['wandb']) wandbC = copy.deepcopy(config['wandb'])
project = self.consume(wandbC, 'project') project = self.consume(wandbC, 'project')
sweep_id = wandb.sweep(
self.sweep_id = wandb.sweep(
sweep=sweepC, sweep=sweepC,
project=project, project=project,
reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT), reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
settings=wandb.Settings(start_method=self.consume(wandbC, 'start_method', DEFAULT_START_METHOD)), settings=wandb.Settings(start_method=self.consume(wandbC, 'start_method', DEFAULT_START_METHOD)),
) )
wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
def _run_process(self, orig_config, rep_ids, p_ind):
config = copy.deepcopy(orig_config)
if self.consume(config, 'sweep.enable', False):
wandbC = copy.deepcopy(config['wandb'])
wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
else: else:
self.consume(config, 'sweep', {}) self.consume(config, 'sweep', {})
self._run_single(config, rep_ids, p_ind=p_ind) self._run_single(config, rep_ids, p_ind=p_ind)
@ -319,6 +327,7 @@ class Slate():
parser.add_argument("-s", "--slurm", action="store_true") parser.add_argument("-s", "--slurm", action="store_true")
parser.add_argument("-w", "--worker", action="store_true") parser.add_argument("-w", "--worker", action="store_true")
parser.add_argument("-j", "--job_id", default=None, type=int) parser.add_argument("-j", "--job_id", default=None, type=int)
parser.add_argument("-s", "--sweep_id", default=None, type=str)
args = parser.parse_args() args = parser.parse_args()
@ -332,7 +341,7 @@ class Slate():
if args.slurm: if args.slurm:
self.run_slurm(args.config_file, args.experiment) self.run_slurm(args.config_file, args.experiment)
else: else:
self.run_local(args.config_file, args.experiment, args.job_id) self.run_local(args.config_file, args.experiment, args.job_id, args.sweep_id)
class Slate_Runner(): class Slate_Runner():