Only init sweep once (on login-node for slurm)
This commit is contained in:
parent
966dfbcfb8
commit
cb4537e5b9
@@ -169,10 +169,11 @@ class Slate():
|
|||||||
reps_for_job[i % num_jobs].append(i)
|
reps_for_job[i % num_jobs].append(i)
|
||||||
return reps_for_job[job_id-1]
|
return reps_for_job[job_id-1]
|
||||||
|
|
||||||
def run_local(self, filename, name, job_id):
|
def run_local(self, filename, name, job_id, sweep_id):
|
||||||
config = self.load_config(filename, name)
|
config = self.load_config(filename, name)
|
||||||
schedulerC = copy.deepcopy(config.get('scheduler', {}))
|
schedulerC = copy.deepcopy(config.get('scheduler', {}))
|
||||||
rep_ids = self._reps_for_job(schedulerC, job_id)
|
rep_ids = self._reps_for_job(schedulerC, job_id)
|
||||||
|
self.sweep_id = sweep_id
|
||||||
self._fork_processes(config, rep_ids)
|
self._fork_processes(config, rep_ids)
|
||||||
|
|
||||||
def run_slurm(self, filename, name):
|
def run_slurm(self, filename, name):
|
||||||
@@ -189,12 +190,14 @@ class Slate():
|
|||||||
runner = Runner(self, config)
|
runner = Runner(self, config)
|
||||||
runner.setup()
|
runner.setup()
|
||||||
|
|
||||||
|
self._init_sweep(config)
|
||||||
|
|
||||||
python_script = 'main.py'
|
python_script = 'main.py'
|
||||||
sh_lines = ['#!/bin/bash']
|
sh_lines = ['#!/bin/bash']
|
||||||
sh_lines += self.consume(slurmC, 'sh_lines', [])
|
sh_lines += self.consume(slurmC, 'sh_lines', [])
|
||||||
if venv := self.consume(slurmC, 'venv', False):
|
if venv := self.consume(slurmC, 'venv', False):
|
||||||
sh_lines += [f'source activate {venv}']
|
sh_lines += [f'source activate {venv}']
|
||||||
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
|
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID -s {self.sweep_id}']
|
||||||
script = "\n".join(sh_lines)
|
script = "\n".join(sh_lines)
|
||||||
|
|
||||||
num_jobs = self._calc_num_jobs(schedC)
|
num_jobs = self._calc_num_jobs(schedC)
|
||||||
@@ -218,7 +221,7 @@ class Slate():
|
|||||||
|
|
||||||
if num_p == 1:
|
if num_p == 1:
|
||||||
print('[i] Running within main thread')
|
print('[i] Running within main thread')
|
||||||
self._run_single(config, rep_ids=rep_ids, p_ind=0)
|
self._run_process(config, rep_ids=rep_ids, p_ind=0)
|
||||||
return
|
return
|
||||||
|
|
||||||
procs = []
|
procs = []
|
||||||
@@ -238,19 +241,24 @@ class Slate():
|
|||||||
proc.join()
|
proc.join()
|
||||||
print(f'[i] All threads/processes have terminated')
|
print(f'[i] All threads/processes have terminated')
|
||||||
|
|
||||||
def _run_process(self, orig_config, rep_ids, p_ind):
|
def _init_sweep(self, config):
|
||||||
config = copy.deepcopy(orig_config)
|
|
||||||
if self.consume(config, 'sweep.enable', False):
|
if self.consume(config, 'sweep.enable', False):
|
||||||
sweepC = self.consume(config, 'sweep')
|
sweepC = self.consume(config, 'sweep')
|
||||||
wandbC = copy.deepcopy(config['wandb'])
|
wandbC = copy.deepcopy(config['wandb'])
|
||||||
project = self.consume(wandbC, 'project')
|
project = self.consume(wandbC, 'project')
|
||||||
sweep_id = wandb.sweep(
|
|
||||||
|
self.sweep_id = wandb.sweep(
|
||||||
sweep=sweepC,
|
sweep=sweepC,
|
||||||
project=project,
|
project=project,
|
||||||
reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
|
reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
|
||||||
settings=wandb.Settings(start_method=self.consume(wandbC, 'start_method', DEFAULT_START_METHOD)),
|
settings=wandb.Settings(start_method=self.consume(wandbC, 'start_method', DEFAULT_START_METHOD)),
|
||||||
)
|
)
|
||||||
wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
|
|
||||||
|
def _run_process(self, orig_config, rep_ids, p_ind):
|
||||||
|
config = copy.deepcopy(orig_config)
|
||||||
|
if self.consume(config, 'sweep.enable', False):
|
||||||
|
wandbC = copy.deepcopy(config['wandb'])
|
||||||
|
wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
|
||||||
else:
|
else:
|
||||||
self.consume(config, 'sweep', {})
|
self.consume(config, 'sweep', {})
|
||||||
self._run_single(config, rep_ids, p_ind=p_ind)
|
self._run_single(config, rep_ids, p_ind=p_ind)
|
||||||
@@ -319,6 +327,7 @@ class Slate():
|
|||||||
parser.add_argument("-s", "--slurm", action="store_true")
|
parser.add_argument("-s", "--slurm", action="store_true")
|
||||||
parser.add_argument("-w", "--worker", action="store_true")
|
parser.add_argument("-w", "--worker", action="store_true")
|
||||||
parser.add_argument("-j", "--job_id", default=None, type=int)
|
parser.add_argument("-j", "--job_id", default=None, type=int)
|
||||||
|
parser.add_argument("-s", "--sweep_id", default=None, type=str)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -332,7 +341,7 @@ class Slate():
|
|||||||
if args.slurm:
|
if args.slurm:
|
||||||
self.run_slurm(args.config_file, args.experiment)
|
self.run_slurm(args.config_file, args.experiment)
|
||||||
else:
|
else:
|
||||||
self.run_local(args.config_file, args.experiment, args.job_id)
|
self.run_local(args.config_file, args.experiment, args.job_id, args.sweep_id)
|
||||||
|
|
||||||
|
|
||||||
class Slate_Runner():
|
class Slate_Runner():
|
||||||
|
Loading…
Reference in New Issue
Block a user