Multiple Processes for Slurm
parent be76a5363a
commit 5e8d6d2552

slate/slate.py (164 changed lines)
@@ -6,7 +6,7 @@ import random
 import copy
 import collections.abc
 from functools import partial
-from multiprocessing import Pool
+from multiprocessing import Process

 import pdb
 d = pdb.set_trace
@@ -114,22 +114,116 @@ class Slate():
         child_keys = '.'.join(keys_arr[1:])
         return self.consume(child, child_keys, default=default)

-    def run_local(self, filename, name, job_num=None):
+    def _calc_num_jobs(self, schedulerC):
+        reps = schedulerC.get('repetitions', 1)
+        agents_per_job = schedulerC.get('agents_per_job', 1)
+        reps_per_agent = schedulerC.get('reps_per_agent', 1)
+        reps_per_job = reps_per_agent * agents_per_job
+        jobs_needed = math.ceil(reps / reps_per_job)
+        return jobs_needed
+
+    def _reps_for_job(self, schedulerC, job_id):
+        reps = schedulerC.get('repetitions', 1)
+        num_jobs = self._calc_num_jobs(schedulerC)
+        reps_for_job = [[] for _ in range(num_jobs)]
+        for i in range(reps):
+            reps_for_job[i % num_jobs].append(i)
+        return reps_for_job[job_id]
+
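The two helpers above decide how many Slurm array jobs are needed (ceil of repetitions over agents_per_job * reps_per_agent) and which repetition indices a given job handles (round-robin). A small worked sketch, using assumed scheduler values rather than anything from this commit:

import math

reps, agents_per_job, reps_per_agent = 10, 2, 3   # assumed example values
reps_per_job = reps_per_agent * agents_per_job    # 6 repetitions handled per array job
num_jobs = math.ceil(reps / reps_per_job)         # ceil(10 / 6) = 2 jobs

# round-robin split, mirroring _reps_for_job
reps_for_job = [[] for _ in range(num_jobs)]
for i in range(reps):
    reps_for_job[i % num_jobs].append(i)

print(reps_for_job[0])  # [0, 2, 4, 6, 8] -> repetitions run by array task 0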
+    def run_local(self, filename, name, job_id=0):
         config = self.load_config(filename, name)
+        schedulerC = copy.deepcopy(config.get('scheduler', {}))
+        rep_ids = self._reps_for_job(schedulerC, job_id)
+        self._fork_processes(config, rep_ids)
+
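run_local now resolves the repetition ids that belong to one array task and forks worker processes for them; the job id is expected to come from the Slurm array index on the cluster and defaults to 0 for a plain local run. An illustrative invocation (config file and experiment name are placeholders):

python3 main.py example_config.yaml DEFAULT -j $SLURM_ARRAY_TASK_ID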
+    def run_slurm(self, filename, name):
+        assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
+        config = self.load_config(filename, name)
+        slurmC = self.consume(config, 'slurm')
+        schedC = self.consume(config, 'scheduler')
+        s_name = self.consume(slurmC, 'name')
+
+        python_script = 'main.py'
+        sh_lines = ['#!/bin/bash']
+        sh_lines += self.consume(slurmC, 'sh_lines', [])
+        if venv := self.consume(slurmC, 'venv', False):
+            sh_lines += [f'source activate {venv}']
+        sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
+        script = "\n".join(sh_lines)
+
+        num_jobs = self._calc_num_jobs(schedC)
+
+        last_job_idx = num_jobs - 1
+        num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
+        array = f'0-{last_job_idx}%{num_parallel_jobs}'
+        job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
+        job_id = job.submit()
+        print(f'[i] Job submitted to slurm with id {job_id}')
+
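run_slurm builds the batch script line by line and submits it as a job array through pyslurm, with one array task per job computed by _calc_num_jobs. Under assumed values (venv 'slate_env', config file 'example_config.yaml', experiment 'DEFAULT', 2 jobs, no extra sh_lines), the generated script and array spec would look roughly like:

#!/bin/bash
source activate slate_env
python3 main.py example_config.yaml DEFAULT -j $SLURM_ARRAY_TASK_ID

array = '0-1%2'   # 2 array tasks, at most 2 running in parallel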
+    def _fork_processes(self, config, rep_ids):
+        schedC = self.consume(config, 'scheduler')
+        agents_per_job = self.consume(schedC, 'agents_per_job', 1)
+        reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
+
+        node_reps = len(rep_ids)
+        num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))
+
+        if num_p == 1:
+            self._run_single(config, rep_ids=rep_ids, p_ind=0)
+            return
+
+        procs = []
+
+        reps_done = 0
+
+        for p in range(num_p):
+            num_reps = min(node_reps - reps_done, reps_per_agent)
+            proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done+num_reps))]
+            proc = Process(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
+            proc.start()
+            procs.append(proc)
+            reps_done += num_reps
+
+        for proc in procs:
+            proc.join()
+
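_fork_processes chunks the job's repetition ids into contiguous groups of at most reps_per_agent and starts one multiprocessing.Process per group, capped at agents_per_job workers. A worked sketch with assumed numbers:

import math

rep_ids = [0, 2, 4, 6, 8]            # assumed: repetitions assigned to this job
agents_per_job, reps_per_agent = 4, 2

node_reps = len(rep_ids)                                             # 5
num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))   # min(4, 3) = 3 processes

chunks, reps_done = [], 0
for p in range(num_p):
    num_reps = min(node_reps - reps_done, reps_per_agent)
    chunks.append(rep_ids[reps_done:reps_done + num_reps])
    reps_done += num_reps

print(chunks)  # [[0, 2], [4, 6], [8]]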
+    def _run_process(self, orig_config, rep_ids, p_ind):
+        config = copy.deepcopy(orig_config)
         if self.consume(config, 'sweep.enable', False):
             sweepC = self.consume(config, 'sweep')
-            project = self.consume(config, 'wandb.project')
+            project = config['wandb']['project']
             sweep_id = wandb.sweep(
                 sweep=sweepC,
                 project=project
             )
-            runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {})
-            wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, runnerName, project, wandbC), count=config['run']['reps_per_agent'])
+            wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
         else:
             self.consume(config, 'sweep', {})
-            self.run_single(config)
+            self._run_single(config, rep_ids, p_ind=p_ind)
+
+    def _run_single(self, orig_config, rep_ids, p_ind):
+        print(f'[P{p_ind}] I will work on reps {rep_ids}')
+        runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {})
+        project = self.consume(wandbC, 'project')
+
+        runner = self.runners[runnerName]
+
+        for r in rep_ids:
+            config = copy.deepcopy(orig_config)
+            with wandb.init(
+                project=project,
+                config=config,
+                **wandbC
+            ) as run:
+                runner(self, run, config)
+
+            assert config == {}, ('Config was not completely consumed: ', config)
+            orig_config = config
+
+    def _run_from_sweep(self, orig_config, p_ind):
+        runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {})
+        project = self.consume(wandbC, 'project')

-    def _run_from_sweep(self, orig_config, runnerName, project, wandbC):
         runner = self.runners[runnerName]

         with wandb.init(
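Inside each worker, _run_process either runs its repetitions directly (_run_single, one wandb run per repetition) or, when sweep.enable is set, creates a wandb sweep and serves it as an agent for len(rep_ids) trials; the enable flag is consumed before the block is handed to wandb. A minimal sketch of a sweep block that wandb.sweep would accept; the concrete values below are illustrative, only the sweep/enable convention comes from this commit:

sweepC = {
    'method': 'random',                               # search strategy
    'metric': {'name': 'loss', 'goal': 'minimize'},
    'parameters': {'lr': {'min': 1e-5, 'max': 1e-2}},
}
sweep_id = wandb.sweep(sweep=sweepC, project=project)  # as done in _run_process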
@@ -141,57 +235,7 @@ class Slate():
             runner(self, run, config)

         assert config == {}, ('Config was not completely consumed: ', config)
+        orig_config = config

-    def run_slurm(self, filename, name):
-        assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
-        config = self.load_config(filename, name)
-        slurmC = self.consume(config, 'slurm')
-        s_name = self.consume(slurmC, 'name')
-
-        python_script = 'main.py'
-        sh_lines = ['#!/bin/bash']
-        sh_lines += self.consume(slurmC, 'sh_lines', [])
-        if venv := self.consume(slurmC, 'venv', False):
-            sh_lines += [f'source activate {venv}']
-        sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
-        script = "\n".join(sh_lines)
-
-        num_jobs = 1
-
-        last_job_idx = num_jobs - 1
-        num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
-        array = f'0-{last_job_idx}%{num_parallel_jobs}'
-        job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
-        job_id = job.submit()
-        print(f'[i] Job submitted to slurm with id {job_id}')
-
-    def run_parallel(self, config):
-        schedC = self.consume(config, 'scheduler')
-        repetitions = self.consume(schedC, 'repetitions')
-        agents_per_job = self.consume(schedC, 'agents_per_job')
-        reps_per_agent = self.consume(schedC, 'reps_per_agent')
-        assert schedC == {}
-
-        num_p = min(agents_per_job, math.ceil(repetitions / reps_per_agent))
-
-        if num_p == 1:
-            return [self.run_single(config, max_reps=reps_per_agent)]
-
-        return Pool(processes=num_p).map(partial(self.run_single, config, max_reps=reps_per_agent), range(num_p))
-
-    def run_single(self, config, max_reps=-1, p_ind=0):
-        runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {})
-
-        runner = self.runners[runnerName]
-
-        with wandb.init(
-            project=self.consume(wandbC, 'project'),
-            config=config,
-            **wandbC
-        ) as run:
-            runner(self, run, config)
-
-        assert config == {}, ('Config was not completely consumed: ', config)
-
     def from_args(self):
         import argparse
@@ -201,7 +245,7 @@ class Slate():
         parser.add_argument("experiment", nargs='?', default='DEFAULT')
         parser.add_argument("-s", "--slurm", action="store_true")
         parser.add_argument("-w", "--worker", action="store_true")
parser.add_argument("-j", "--job_num", default=None)
|
parser.add_argument("-j", "--job_id", default=0)
|
||||||

         args = parser.parse_args()

@@ -212,7 +256,7 @@ class Slate():
         if args.slurm:
             self.run_slurm(args.config_file, args.experiment)
         else:
-            self.run_local(args.config_file, args.experiment, args.job_num)
+            self.run_local(args.config_file, args.experiment, args.job_id)


 def print_config_runner(slate, run, config):
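Taken together, the keys consumed by the new scheduling path suggest a config of roughly the following shape; every value below is an assumed placeholder, only the key names are taken from the diff:

config = {
    'runner': 'my_runner',                 # name registered in self.runners (assumed)
    'wandb': {'project': 'my_project'},
    'scheduler': {
        'repetitions': 10,
        'agents_per_job': 2,
        'reps_per_agent': 3,
    },
    'slurm': {                             # only read by run_slurm
        'name': 'slate_run',
        'venv': 'slate_env',
        'sh_lines': [],
        'num_parallel_jobs': 4,
        # remaining keys are forwarded to pyslurm.JobSubmitDescription
    },
    'sweep': {'enable': False},
}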