diff --git a/slate/slate.py b/slate/slate.py index b0c422e..b7913d5 100644 --- a/slate/slate.py +++ b/slate/slate.py @@ -1,10 +1,12 @@ import wandb import yaml import os +import math import random import copy import collections.abc from functools import partial +from multiprocessing import Pool import pdb d = pdb.set_trace @@ -163,7 +165,21 @@ class Slate(): job_id = job.submit() print(f'[i] Job submitted to slurm with id {job_id}') - def run_single(self, config): + def run_parallel(self, config): + schedC = self.consume(config, 'scheduler') + repetitions = self.consume(schedC, 'repetitions') + agents_per_job = self.consume(schedC, 'agents_per_job') + reps_per_agent = self.consume(schedC, 'reps_per_agent') + assert schedC == {} + + num_p = min(agents_per_job, math.ceil(repetitions / reps_per_agent)) + + if num_p == 1: + return [self.run_single(config, max_reps=reps_per_agent)] + + return Pool(processes=num_p).map(partial(self.run_single, config, max_reps=reps_per_agent), range(num_p)) + + def run_single(self, config, max_reps=-1, p_ind=0): runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {}) runner = self.runners[runnerName]