Only initialize the sweep once (on the login node when running under Slurm)
This commit is contained in:
parent
966dfbcfb8
commit
cb4537e5b9
@ -169,10 +169,11 @@ class Slate():
|
||||
reps_for_job[i % num_jobs].append(i)
|
||||
return reps_for_job[job_id-1]
|
||||
|
||||
def run_local(self, filename, name, job_id):
|
||||
def run_local(self, filename, name, job_id, sweep_id):
|
||||
config = self.load_config(filename, name)
|
||||
schedulerC = copy.deepcopy(config.get('scheduler', {}))
|
||||
rep_ids = self._reps_for_job(schedulerC, job_id)
|
||||
self.sweep_id = sweep_id
|
||||
self._fork_processes(config, rep_ids)
|
||||
|
||||
def run_slurm(self, filename, name):
|
||||
@ -189,12 +190,14 @@ class Slate():
|
||||
runner = Runner(self, config)
|
||||
runner.setup()
|
||||
|
||||
self._init_sweep(config)
|
||||
|
||||
python_script = 'main.py'
|
||||
sh_lines = ['#!/bin/bash']
|
||||
sh_lines += self.consume(slurmC, 'sh_lines', [])
|
||||
if venv := self.consume(slurmC, 'venv', False):
|
||||
sh_lines += [f'source activate {venv}']
|
||||
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
|
||||
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID -s {self.sweep_id}']
|
||||
script = "\n".join(sh_lines)
|
||||
|
||||
num_jobs = self._calc_num_jobs(schedC)
|
||||
@ -218,7 +221,7 @@ class Slate():
|
||||
|
||||
if num_p == 1:
|
||||
print('[i] Running within main thread')
|
||||
self._run_single(config, rep_ids=rep_ids, p_ind=0)
|
||||
self._run_process(config, rep_ids=rep_ids, p_ind=0)
|
||||
return
|
||||
|
||||
procs = []
|
||||
@ -238,19 +241,24 @@ class Slate():
|
||||
proc.join()
|
||||
print(f'[i] All threads/processes have terminated')
|
||||
|
||||
def _run_process(self, orig_config, rep_ids, p_ind):
|
||||
config = copy.deepcopy(orig_config)
|
||||
def _init_sweep(self, config):
|
||||
if self.consume(config, 'sweep.enable', False):
|
||||
sweepC = self.consume(config, 'sweep')
|
||||
wandbC = copy.deepcopy(config['wandb'])
|
||||
project = self.consume(wandbC, 'project')
|
||||
sweep_id = wandb.sweep(
|
||||
|
||||
self.sweep_id = wandb.sweep(
|
||||
sweep=sweepC,
|
||||
project=project,
|
||||
reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
|
||||
settings=wandb.Settings(start_method=self.consume(wandbC, 'start_method', DEFAULT_START_METHOD)),
|
||||
)
|
||||
wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
|
||||
|
||||
def _run_process(self, orig_config, rep_ids, p_ind):
|
||||
config = copy.deepcopy(orig_config)
|
||||
if self.consume(config, 'sweep.enable', False):
|
||||
wandbC = copy.deepcopy(config['wandb'])
|
||||
wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
|
||||
else:
|
||||
self.consume(config, 'sweep', {})
|
||||
self._run_single(config, rep_ids, p_ind=p_ind)
|
||||
@ -319,6 +327,7 @@ class Slate():
|
||||
parser.add_argument("-s", "--slurm", action="store_true")
|
||||
parser.add_argument("-w", "--worker", action="store_true")
|
||||
parser.add_argument("-j", "--job_id", default=None, type=int)
|
||||
parser.add_argument("-s", "--sweep_id", default=None, type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -332,7 +341,7 @@ class Slate():
|
||||
if args.slurm:
|
||||
self.run_slurm(args.config_file, args.experiment)
|
||||
else:
|
||||
self.run_local(args.config_file, args.experiment, args.job_id)
|
||||
self.run_local(args.config_file, args.experiment, args.job_id, args.sweep_id)
|
||||
|
||||
|
||||
class Slate_Runner():
|
||||
|
Loading…
Reference in New Issue
Block a user