Multiple Processes fro Slurm

This commit is contained in:
Dominik Moritz Roth 2023-07-07 14:39:38 +02:00
parent be76a5363a
commit 5e8d6d2552

View File

@ -6,7 +6,7 @@ import random
import copy import copy
import collections.abc import collections.abc
from functools import partial from functools import partial
from multiprocessing import Pool from multiprocessing import Process
import pdb import pdb
d = pdb.set_trace d = pdb.set_trace
@ -114,22 +114,116 @@ class Slate():
child_keys = '.'.join(keys_arr[1:]) child_keys = '.'.join(keys_arr[1:])
return self.consume(child, child_keys, default=default) return self.consume(child, child_keys, default=default)
def run_local(self, filename, name, job_num=None): def _calc_num_jobs(self, schedulerC):
reps = schedulerC.get('repetitions', 1)
agents_per_job = schedulerC.get('agents_per_job', 1)
reps_per_agent = schedulerC.get('reps_per_agent', 1)
reps_per_job = reps_per_agent * agents_per_job
jobs_needed = math.ceil(reps / reps_per_job)
return jobs_needed
def _reps_for_job(self, schedulerC, job_id):
reps = schedulerC.get('repetitions', 1)
num_jobs = self._calc_num_jobs(schedulerC)
reps_for_job = [[]] * num_jobs
for i in range(reps):
reps_for_job[i % num_jobs].append(i)
return reps_for_job[job_id]
def run_local(self, filename, name, job_id=0):
config = self.load_config(filename, name) config = self.load_config(filename, name)
schedulerC = copy.deepcopy(config.get('scheduler', {}))
rep_ids = self._reps_for_job(schedulerC, job_id)
self._fork_processes(config, rep_ids)
def run_slurm(self, filename, name):
assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
config = self.load_config(filename, name)
slurmC = self.consume(config, 'slurm')
schedC = self.consume(config, 'scheduler')
s_name = self.consume(slurmC, 'name')
python_script = 'main.py'
sh_lines = ['#!/bin/bash']
sh_lines += self.consume(slurmC, 'sh_lines', [])
if venv := self.consume(slurmC, 'venv', False):
sh_lines += [f'source activate {venv}']
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
script = "\n".join(sh_lines)
num_jobs = self._calc_num_jobs(schedC)
last_job_idx = num_jobs - 1
num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
array = f'0-{last_job_idx}%{num_parallel_jobs}'
job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
job_id = job.submit()
print(f'[i] Job submitted to slurm with id {job_id}')
def _fork_processes(self, config, rep_ids):
schedC = self.consume(config, 'scheduler')
agents_per_job = self.consume(schedC, 'agents_per_job', 1)
reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
node_reps = len(rep_ids)
num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))
if num_p == 1:
self._run_single(config, rep_ids=rep_ids, p_ind=0)
return
procs = []
reps_done = 0
for p in range(num_p):
num_reps = min(node_reps - reps_done, reps_per_agent)
proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done+num_reps))]
proc = Process(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
proc.start()
procs.append(proc)
reps_done += num_reps
for proc in procs:
proc.join()
def _run_process(self, orig_config, rep_ids, p_ind):
config = copy.deepcopy(orig_config)
if self.consume(config, 'sweep.enable', False): if self.consume(config, 'sweep.enable', False):
sweepC = self.consume(config, 'sweep') sweepC = self.consume(config, 'sweep')
project = self.consume(config, 'wandb.project') project = config['wandb']['project']
sweep_id = wandb.sweep( sweep_id = wandb.sweep(
sweep=sweepC, sweep=sweepC,
project=project project=project
) )
runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {}) wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, runnerName, project, wandbC), count=config['run']['reps_per_agent'])
else: else:
self.consume(config, 'sweep', {}) self.consume(config, 'sweep', {})
self.run_single(config) self._run_single(config, rep_ids, p_ind=p_ind)
def _run_single(self, orig_config, rep_ids, p_ind):
print(f'[P{p_ind}] I will work on reps {rep_ids}')
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {})
project = self.consume(wandbC, 'project')
runner = self.runners[runnerName]
for r in rep_ids:
config = copy.deepcopy(orig_config)
with wandb.init(
project=project,
config=config,
**wandbC
) as run:
runner(self, run, config)
assert config == {}, ('Config was not completely consumed: ', config)
orig_config = config
def _run_from_sweep(self, orig_config, p_ind):
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {})
project = self.consume(wandbC, 'project')
def _run_from_sweep(self, orig_config, runnerName, project, wandbC):
runner = self.runners[runnerName] runner = self.runners[runnerName]
with wandb.init( with wandb.init(
@ -141,57 +235,7 @@ class Slate():
runner(self, run, config) runner(self, run, config)
assert config == {}, ('Config was not completely consumed: ', config) assert config == {}, ('Config was not completely consumed: ', config)
orig_config = config
def run_slurm(self, filename, name):
assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
config = self.load_config(filename, name)
slurmC = self.consume(config, 'slurm')
s_name = self.consume(slurmC, 'name')
python_script = 'main.py'
sh_lines = ['#!/bin/bash']
sh_lines += self.consume(slurmC, 'sh_lines', [])
if venv := self.consume(slurmC, 'venv', False):
sh_lines += [f'source activate {venv}']
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
script = "\n".join(sh_lines)
num_jobs = 1
last_job_idx = num_jobs - 1
num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
array = f'0-{last_job_idx}%{num_parallel_jobs}'
job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
job_id = job.submit()
print(f'[i] Job submitted to slurm with id {job_id}')
def run_parallel(self, config):
schedC = self.consume(config, 'scheduler')
repetitions = self.consume(schedC, 'repetitions')
agents_per_job = self.consume(schedC, 'agents_per_job')
reps_per_agent = self.consume(schedC, 'reps_per_agent')
assert schedC == {}
num_p = min(agents_per_job, math.ceil(repetitions / reps_per_agent))
if num_p == 1:
return [self.run_single(config, max_reps=reps_per_agent)]
return Pool(processes=num_p).map(partial(self.run_single, config, max_reps=reps_per_agent), range(num_p))
def run_single(self, config, max_reps=-1, p_ind=0):
runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {})
runner = self.runners[runnerName]
with wandb.init(
project=self.consume(wandbC, 'project'),
config=config,
**wandbC
) as run:
runner(self, run, config)
assert config == {}, ('Config was not completely consumed: ', config)
def from_args(self): def from_args(self):
import argparse import argparse
@ -201,7 +245,7 @@ class Slate():
parser.add_argument("experiment", nargs='?', default='DEFAULT') parser.add_argument("experiment", nargs='?', default='DEFAULT')
parser.add_argument("-s", "--slurm", action="store_true") parser.add_argument("-s", "--slurm", action="store_true")
parser.add_argument("-w", "--worker", action="store_true") parser.add_argument("-w", "--worker", action="store_true")
parser.add_argument("-j", "--job_num", default=None) parser.add_argument("-j", "--job_id", default=0)
args = parser.parse_args() args = parser.parse_args()
@ -212,7 +256,7 @@ class Slate():
if args.slurm: if args.slurm:
self.run_slurm(args.config_file, args.experiment) self.run_slurm(args.config_file, args.experiment)
else: else:
self.run_local(args.config_file, args.experiment, args.job_num) self.run_local(args.config_file, args.experiment, args.job_id)
def print_config_runner(slate, run, config): def print_config_runner(slate, run, config):