Multiple Processes fro Slurm

This commit is contained in:
Dominik Moritz Roth 2023-07-07 14:39:38 +02:00
parent be76a5363a
commit 5e8d6d2552

View File

@ -6,7 +6,7 @@ import random
import copy import copy
import collections.abc import collections.abc
from functools import partial from functools import partial
from multiprocessing import Pool from multiprocessing import Process
import pdb import pdb
d = pdb.set_trace d = pdb.set_trace
@ -114,38 +114,33 @@ class Slate():
child_keys = '.'.join(keys_arr[1:]) child_keys = '.'.join(keys_arr[1:])
return self.consume(child, child_keys, default=default) return self.consume(child, child_keys, default=default)
def run_local(self, filename, name, job_num=None): def _calc_num_jobs(self, schedulerC):
reps = schedulerC.get('repetitions', 1)
agents_per_job = schedulerC.get('agents_per_job', 1)
reps_per_agent = schedulerC.get('reps_per_agent', 1)
reps_per_job = reps_per_agent * agents_per_job
jobs_needed = math.ceil(reps / reps_per_job)
return jobs_needed
def _reps_for_job(self, schedulerC, job_id):
reps = schedulerC.get('repetitions', 1)
num_jobs = self._calc_num_jobs(schedulerC)
reps_for_job = [[]] * num_jobs
for i in range(reps):
reps_for_job[i % num_jobs].append(i)
return reps_for_job[job_id]
def run_local(self, filename, name, job_id=0):
config = self.load_config(filename, name) config = self.load_config(filename, name)
if self.consume(config, 'sweep.enable', False): schedulerC = copy.deepcopy(config.get('scheduler', {}))
sweepC = self.consume(config, 'sweep') rep_ids = self._reps_for_job(schedulerC, job_id)
project = self.consume(config, 'wandb.project') self._fork_processes(config, rep_ids)
sweep_id = wandb.sweep(
sweep=sweepC,
project=project
)
runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {})
wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, runnerName, project, wandbC), count=config['run']['reps_per_agent'])
else:
self.consume(config, 'sweep', {})
self.run_single(config)
def _run_from_sweep(self, orig_config, runnerName, project, wandbC):
runner = self.runners[runnerName]
with wandb.init(
project=project,
**wandbC
) as run:
config = copy.deepcopy(orig_config)
self.deep_update(config, wandb.config)
runner(self, run, config)
assert config == {}, ('Config was not completely consumed: ', config)
def run_slurm(self, filename, name): def run_slurm(self, filename, name):
assert slurm_avaible, 'pyslurm does not seem to be installed on this system.' assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
config = self.load_config(filename, name) config = self.load_config(filename, name)
slurmC = self.consume(config, 'slurm') slurmC = self.consume(config, 'slurm')
schedC = self.consume(config, 'scheduler')
s_name = self.consume(slurmC, 'name') s_name = self.consume(slurmC, 'name')
python_script = 'main.py' python_script = 'main.py'
@ -156,7 +151,7 @@ class Slate():
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID'] sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
script = "\n".join(sh_lines) script = "\n".join(sh_lines)
num_jobs = 1 num_jobs = self._calc_num_jobs(schedC)
last_job_idx = num_jobs - 1 last_job_idx = num_jobs - 1
num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs) num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
@ -165,33 +160,82 @@ class Slate():
job_id = job.submit() job_id = job.submit()
print(f'[i] Job submitted to slurm with id {job_id}') print(f'[i] Job submitted to slurm with id {job_id}')
def run_parallel(self, config): def _fork_processes(self, config, rep_ids):
schedC = self.consume(config, 'scheduler') schedC = self.consume(config, 'scheduler')
repetitions = self.consume(schedC, 'repetitions') agents_per_job = self.consume(schedC, 'agents_per_job', 1)
agents_per_job = self.consume(schedC, 'agents_per_job') reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
reps_per_agent = self.consume(schedC, 'reps_per_agent')
assert schedC == {}
num_p = min(agents_per_job, math.ceil(repetitions / reps_per_agent)) node_reps = len(rep_ids)
num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))
if num_p == 1: if num_p == 1:
return [self.run_single(config, max_reps=reps_per_agent)] self._run_single(config, rep_ids=rep_ids, p_ind=0)
return
return Pool(processes=num_p).map(partial(self.run_single, config, max_reps=reps_per_agent), range(num_p)) procs = []
def run_single(self, config, max_reps=-1, p_ind=0): reps_done = 0
runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {})
for p in range(num_p):
num_reps = min(node_reps - reps_done, reps_per_agent)
proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done+num_reps))]
proc = Process(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
proc.start()
procs.append(proc)
reps_done += num_reps
for proc in procs:
proc.join()
def _run_process(self, orig_config, rep_ids, p_ind):
config = copy.deepcopy(orig_config)
if self.consume(config, 'sweep.enable', False):
sweepC = self.consume(config, 'sweep')
project = config['wandb']['project']
sweep_id = wandb.sweep(
sweep=sweepC,
project=project
)
wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
else:
self.consume(config, 'sweep', {})
self._run_single(config, rep_ids, p_ind=p_ind)
def _run_single(self, orig_config, rep_ids, p_ind):
print(f'[P{p_ind}] I will work on reps {rep_ids}')
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {})
project = self.consume(wandbC, 'project')
runner = self.runners[runnerName]
for r in rep_ids:
config = copy.deepcopy(orig_config)
with wandb.init(
project=project,
config=config,
**wandbC
) as run:
runner(self, run, config)
assert config == {}, ('Config was not completely consumed: ', config)
orig_config = config
def _run_from_sweep(self, orig_config, p_ind):
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {})
project = self.consume(wandbC, 'project')
runner = self.runners[runnerName] runner = self.runners[runnerName]
with wandb.init( with wandb.init(
project=self.consume(wandbC, 'project'), project=project,
config=config,
**wandbC **wandbC
) as run: ) as run:
config = copy.deepcopy(orig_config)
self.deep_update(config, wandb.config)
runner(self, run, config) runner(self, run, config)
assert config == {}, ('Config was not completely consumed: ', config) assert config == {}, ('Config was not completely consumed: ', config)
orig_config = config
def from_args(self): def from_args(self):
import argparse import argparse
@ -201,7 +245,7 @@ class Slate():
parser.add_argument("experiment", nargs='?', default='DEFAULT') parser.add_argument("experiment", nargs='?', default='DEFAULT')
parser.add_argument("-s", "--slurm", action="store_true") parser.add_argument("-s", "--slurm", action="store_true")
parser.add_argument("-w", "--worker", action="store_true") parser.add_argument("-w", "--worker", action="store_true")
parser.add_argument("-j", "--job_num", default=None) parser.add_argument("-j", "--job_id", default=0)
args = parser.parse_args() args = parser.parse_args()
@ -212,7 +256,7 @@ class Slate():
if args.slurm: if args.slurm:
self.run_slurm(args.config_file, args.experiment) self.run_slurm(args.config_file, args.experiment)
else: else:
self.run_local(args.config_file, args.experiment, args.job_num) self.run_local(args.config_file, args.experiment, args.job_id)
def print_config_runner(slate, run, config): def print_config_runner(slate, run, config):