diff --git a/slate/slate.py b/slate/slate.py
index b7913d5..998f0fe 100644
--- a/slate/slate.py
+++ b/slate/slate.py
@@ -6,7 +6,7 @@ import random
 import copy
 import collections.abc
 from functools import partial
-from multiprocessing import Pool
+from multiprocessing import Process
 
 import pdb
 d = pdb.set_trace
@@ -114,38 +114,37 @@ class Slate():
         child_keys = '.'.join(keys_arr[1:])
         return self.consume(child, child_keys, default=default)
 
-    def run_local(self, filename, name, job_num=None):
+    def _calc_num_jobs(self, schedulerC):
+        """Return how many jobs are needed to cover all scheduled repetitions."""
+        reps = schedulerC.get('repetitions', 1)
+        agents_per_job = schedulerC.get('agents_per_job', 1)
+        reps_per_agent = schedulerC.get('reps_per_agent', 1)
+        reps_per_job = reps_per_agent * agents_per_job
+        jobs_needed = math.ceil(reps / reps_per_job)
+        return jobs_needed
+
+    def _reps_for_job(self, schedulerC, job_id):
+        """Return the repetition ids assigned (round-robin) to job `job_id`."""
+        reps = schedulerC.get('repetitions', 1)
+        num_jobs = self._calc_num_jobs(schedulerC)
+        # One independent list per job; `[[]] * n` would alias a single list.
+        reps_for_job = [[] for _ in range(num_jobs)]
+        for i in range(reps):
+            reps_for_job[i % num_jobs].append(i)
+        return reps_for_job[job_id]
+
+    def run_local(self, filename, name, job_id=0):
+        """Load the named config and run the repetitions that belong to job `job_id`."""
         config = self.load_config(filename, name)
-        if self.consume(config, 'sweep.enable', False):
-            sweepC = self.consume(config, 'sweep')
-            project = self.consume(config, 'wandb.project')
-            sweep_id = wandb.sweep(
-                sweep=sweepC,
-                project=project
-            )
-            runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {})
-            wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, runnerName, project, wandbC), count=config['run']['reps_per_agent'])
-        else:
-            self.consume(config, 'sweep', {})
-            self.run_single(config)
-
-    def _run_from_sweep(self, orig_config, runnerName, project, wandbC):
-        runner = self.runners[runnerName]
-
-        with wandb.init(
-            project=project,
-            **wandbC
-        ) as run:
-            config = copy.deepcopy(orig_config)
-            self.deep_update(config, wandb.config)
-            runner(self, run, config)
-
-        assert config == {}, ('Config was not completely consumed: ', config)
+        schedulerC = copy.deepcopy(config.get('scheduler', {}))
+        rep_ids = self._reps_for_job(schedulerC, job_id)
+        self._fork_processes(config, rep_ids)
 
     def run_slurm(self, filename, name):
         assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
         config = self.load_config(filename, name)
         slurmC = self.consume(config, 'slurm')
+        schedC = self.consume(config, 'scheduler')
         s_name = self.consume(slurmC, 'name')
 
         python_script = 'main.py'
@@ -156,7 +155,7 @@ class Slate():
         sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
         script = "\n".join(sh_lines)
 
-        num_jobs = 1
+        num_jobs = self._calc_num_jobs(schedC)
         last_job_idx = num_jobs - 1
         num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
 
@@ -165,33 +164,87 @@ class Slate():
         job_id = job.submit()
         print(f'[i] Job submitted to slurm with id {job_id}')
 
-    def run_parallel(self, config):
+    def _fork_processes(self, config, rep_ids):
+        """Run `rep_ids`, split over at most `agents_per_job` worker processes."""
         schedC = self.consume(config, 'scheduler')
-        repetitions = self.consume(schedC, 'repetitions')
-        agents_per_job = self.consume(schedC, 'agents_per_job')
-        reps_per_agent = self.consume(schedC, 'reps_per_agent')
-        assert schedC == {}
+        agents_per_job = self.consume(schedC, 'agents_per_job', 1)
+        reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
 
-        num_p = min(agents_per_job, math.ceil(repetitions / reps_per_agent))
+        node_reps = len(rep_ids)
+        num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))
 
         if num_p == 1:
-            return [self.run_single(config, max_reps=reps_per_agent)]
+            # Go through _run_process so the 'sweep' section is handled here too.
+            self._run_process(config, rep_ids=rep_ids, p_ind=0)
+            return
 
-        return Pool(processes=num_p).map(partial(self.run_single, config, max_reps=reps_per_agent), range(num_p))
+        procs = []
 
-    def run_single(self, config, max_reps=-1, p_ind=0):
-        runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {})
+        reps_done = 0
+
+        for p in range(num_p):
+            num_reps = min(node_reps - reps_done, reps_per_agent)
+            proc_rep_ids = rep_ids[reps_done:reps_done + num_reps]
+            proc = Process(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
+            proc.start()
+            procs.append(proc)
+            reps_done += num_reps
+
+        for proc in procs:
+            proc.join()
+
+    def _run_process(self, orig_config, rep_ids, p_ind):
+        """Run `rep_ids` in this process, as a wandb sweep agent if 'sweep.enable' is set."""
+        config = copy.deepcopy(orig_config)
+        if self.consume(config, 'sweep.enable', False):
+            sweepC = self.consume(config, 'sweep')
+            project = config['wandb']['project']
+            sweep_id = wandb.sweep(
+                sweep=sweepC,
+                project=project
+            )
+            wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
+        else:
+            self.consume(config, 'sweep', {})
+            self._run_single(config, rep_ids, p_ind=p_ind)
+
+    def _run_single(self, orig_config, rep_ids, p_ind):
+        """Call the configured runner once per id in `rep_ids`, each in its own wandb run."""
+        print(f'[P{p_ind}] I will work on reps {rep_ids}')
+        runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {})
+        project = self.consume(wandbC, 'project')
+
+        runner = self.runners[runnerName]
+
+        for r in rep_ids:
+            # Fresh copy per repetition: the runner is expected to consume its config.
+            config = copy.deepcopy(orig_config)
+            with wandb.init(
+                project=project,
+                config=config,
+                **wandbC
+            ) as run:
+                runner(self, run, config)
+
+                assert config == {}, ('Config was not completely consumed: ', config)
+
+    def _run_from_sweep(self, orig_config, p_ind):
+        """Run one sweep trial; this is the callback handed to wandb.agent."""
+        # Consume from a copy: orig_config is shared by every call the agent makes.
+        config = copy.deepcopy(orig_config)
+        runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {})
+        project = self.consume(wandbC, 'project')
 
         runner = self.runners[runnerName]
 
         with wandb.init(
-            project=self.consume(wandbC, 'project'),
-            config=config,
+            project=project,
             **wandbC
         ) as run:
+            self.deep_update(config, wandb.config)
             runner(self, run, config)
 
-        assert config == {}, ('Config was not completely consumed: ', config)
+            assert config == {}, ('Config was not completely consumed: ', config)
 
     def from_args(self):
         import argparse
@@ -201,7 +254,7 @@ class Slate():
         parser.add_argument("experiment", nargs='?', default='DEFAULT')
         parser.add_argument("-s", "--slurm", action="store_true")
         parser.add_argument("-w", "--worker", action="store_true")
-        parser.add_argument("-j", "--job_num", default=None)
+        parser.add_argument("-j", "--job_id", type=int, default=0)
 
         args = parser.parse_args()
 
@@ -212,7 +265,7 @@ class Slate():
         if args.slurm:
             self.run_slurm(args.config_file, args.experiment)
         else:
-            self.run_local(args.config_file, args.experiment, args.job_num)
+            self.run_local(args.config_file, args.experiment, args.job_id)
 
 
 def print_config_runner(slate, run, config):