Multiple Processes fro Slurm
This commit is contained in:
parent
be76a5363a
commit
5e8d6d2552
130
slate/slate.py
130
slate/slate.py
@ -6,7 +6,7 @@ import random
|
||||
import copy
|
||||
import collections.abc
|
||||
from functools import partial
|
||||
from multiprocessing import Pool
|
||||
from multiprocessing import Process
|
||||
|
||||
import pdb
|
||||
d = pdb.set_trace
|
||||
@ -114,38 +114,33 @@ class Slate():
|
||||
child_keys = '.'.join(keys_arr[1:])
|
||||
return self.consume(child, child_keys, default=default)
|
||||
|
||||
def run_local(self, filename, name, job_num=None):
|
||||
def _calc_num_jobs(self, schedulerC):
|
||||
reps = schedulerC.get('repetitions', 1)
|
||||
agents_per_job = schedulerC.get('agents_per_job', 1)
|
||||
reps_per_agent = schedulerC.get('reps_per_agent', 1)
|
||||
reps_per_job = reps_per_agent * agents_per_job
|
||||
jobs_needed = math.ceil(reps / reps_per_job)
|
||||
return jobs_needed
|
||||
|
||||
def _reps_for_job(self, schedulerC, job_id):
|
||||
reps = schedulerC.get('repetitions', 1)
|
||||
num_jobs = self._calc_num_jobs(schedulerC)
|
||||
reps_for_job = [[]] * num_jobs
|
||||
for i in range(reps):
|
||||
reps_for_job[i % num_jobs].append(i)
|
||||
return reps_for_job[job_id]
|
||||
|
||||
def run_local(self, filename, name, job_id=0):
|
||||
config = self.load_config(filename, name)
|
||||
if self.consume(config, 'sweep.enable', False):
|
||||
sweepC = self.consume(config, 'sweep')
|
||||
project = self.consume(config, 'wandb.project')
|
||||
sweep_id = wandb.sweep(
|
||||
sweep=sweepC,
|
||||
project=project
|
||||
)
|
||||
runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {})
|
||||
wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, runnerName, project, wandbC), count=config['run']['reps_per_agent'])
|
||||
else:
|
||||
self.consume(config, 'sweep', {})
|
||||
self.run_single(config)
|
||||
|
||||
def _run_from_sweep(self, orig_config, runnerName, project, wandbC):
|
||||
runner = self.runners[runnerName]
|
||||
|
||||
with wandb.init(
|
||||
project=project,
|
||||
**wandbC
|
||||
) as run:
|
||||
config = copy.deepcopy(orig_config)
|
||||
self.deep_update(config, wandb.config)
|
||||
runner(self, run, config)
|
||||
|
||||
assert config == {}, ('Config was not completely consumed: ', config)
|
||||
schedulerC = copy.deepcopy(config.get('scheduler', {}))
|
||||
rep_ids = self._reps_for_job(schedulerC, job_id)
|
||||
self._fork_processes(config, rep_ids)
|
||||
|
||||
def run_slurm(self, filename, name):
|
||||
assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
|
||||
config = self.load_config(filename, name)
|
||||
slurmC = self.consume(config, 'slurm')
|
||||
schedC = self.consume(config, 'scheduler')
|
||||
s_name = self.consume(slurmC, 'name')
|
||||
|
||||
python_script = 'main.py'
|
||||
@ -156,7 +151,7 @@ class Slate():
|
||||
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
|
||||
script = "\n".join(sh_lines)
|
||||
|
||||
num_jobs = 1
|
||||
num_jobs = self._calc_num_jobs(schedC)
|
||||
|
||||
last_job_idx = num_jobs - 1
|
||||
num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
|
||||
@ -165,33 +160,82 @@ class Slate():
|
||||
job_id = job.submit()
|
||||
print(f'[i] Job submitted to slurm with id {job_id}')
|
||||
|
||||
def run_parallel(self, config):
|
||||
def _fork_processes(self, config, rep_ids):
|
||||
schedC = self.consume(config, 'scheduler')
|
||||
repetitions = self.consume(schedC, 'repetitions')
|
||||
agents_per_job = self.consume(schedC, 'agents_per_job')
|
||||
reps_per_agent = self.consume(schedC, 'reps_per_agent')
|
||||
assert schedC == {}
|
||||
agents_per_job = self.consume(schedC, 'agents_per_job', 1)
|
||||
reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
|
||||
|
||||
num_p = min(agents_per_job, math.ceil(repetitions / reps_per_agent))
|
||||
node_reps = len(rep_ids)
|
||||
num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))
|
||||
|
||||
if num_p == 1:
|
||||
return [self.run_single(config, max_reps=reps_per_agent)]
|
||||
self._run_single(config, rep_ids=rep_ids, p_ind=0)
|
||||
return
|
||||
|
||||
return Pool(processes=num_p).map(partial(self.run_single, config, max_reps=reps_per_agent), range(num_p))
|
||||
procs = []
|
||||
|
||||
def run_single(self, config, max_reps=-1, p_ind=0):
|
||||
runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {})
|
||||
reps_done = 0
|
||||
|
||||
for p in range(num_p):
|
||||
num_reps = min(node_reps - reps_done, reps_per_agent)
|
||||
proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done+num_reps))]
|
||||
proc = Process(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
|
||||
proc.start()
|
||||
procs.append(proc)
|
||||
reps_done += num_reps
|
||||
|
||||
for proc in procs:
|
||||
proc.join()
|
||||
|
||||
def _run_process(self, orig_config, rep_ids, p_ind):
|
||||
config = copy.deepcopy(orig_config)
|
||||
if self.consume(config, 'sweep.enable', False):
|
||||
sweepC = self.consume(config, 'sweep')
|
||||
project = config['wandb']['project']
|
||||
sweep_id = wandb.sweep(
|
||||
sweep=sweepC,
|
||||
project=project
|
||||
)
|
||||
wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
|
||||
else:
|
||||
self.consume(config, 'sweep', {})
|
||||
self._run_single(config, rep_ids, p_ind=p_ind)
|
||||
|
||||
def _run_single(self, orig_config, rep_ids, p_ind):
|
||||
print(f'[P{p_ind}] I will work on reps {rep_ids}')
|
||||
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {})
|
||||
project = self.consume(wandbC, 'project')
|
||||
|
||||
runner = self.runners[runnerName]
|
||||
|
||||
for r in rep_ids:
|
||||
config = copy.deepcopy(orig_config)
|
||||
with wandb.init(
|
||||
project=project,
|
||||
config=config,
|
||||
**wandbC
|
||||
) as run:
|
||||
runner(self, run, config)
|
||||
|
||||
assert config == {}, ('Config was not completely consumed: ', config)
|
||||
orig_config = config
|
||||
|
||||
def _run_from_sweep(self, orig_config, p_ind):
|
||||
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {})
|
||||
project = self.consume(wandbC, 'project')
|
||||
|
||||
runner = self.runners[runnerName]
|
||||
|
||||
with wandb.init(
|
||||
project=self.consume(wandbC, 'project'),
|
||||
config=config,
|
||||
project=project,
|
||||
**wandbC
|
||||
) as run:
|
||||
config = copy.deepcopy(orig_config)
|
||||
self.deep_update(config, wandb.config)
|
||||
runner(self, run, config)
|
||||
|
||||
assert config == {}, ('Config was not completely consumed: ', config)
|
||||
assert config == {}, ('Config was not completely consumed: ', config)
|
||||
orig_config = config
|
||||
|
||||
def from_args(self):
|
||||
import argparse
|
||||
@ -201,7 +245,7 @@ class Slate():
|
||||
parser.add_argument("experiment", nargs='?', default='DEFAULT')
|
||||
parser.add_argument("-s", "--slurm", action="store_true")
|
||||
parser.add_argument("-w", "--worker", action="store_true")
|
||||
parser.add_argument("-j", "--job_num", default=None)
|
||||
parser.add_argument("-j", "--job_id", default=0)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -212,7 +256,7 @@ class Slate():
|
||||
if args.slurm:
|
||||
self.run_slurm(args.config_file, args.experiment)
|
||||
else:
|
||||
self.run_local(args.config_file, args.experiment, args.job_num)
|
||||
self.run_local(args.config_file, args.experiment, args.job_id)
|
||||
|
||||
|
||||
def print_config_runner(slate, run, config):
|
||||
|
Loading…
Reference in New Issue
Block a user