From be76a5363a451e7f9c3026c8bab1ca0fca6987d9 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Fri, 7 Jul 2023 13:10:06 +0200 Subject: [PATCH] Implemented Process Pool --- slate/slate.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/slate/slate.py b/slate/slate.py index b0c422e..b7913d5 100644 --- a/slate/slate.py +++ b/slate/slate.py @@ -1,10 +1,12 @@ import wandb import yaml import os +import math import random import copy import collections.abc from functools import partial +from multiprocessing import Pool import pdb d = pdb.set_trace @@ -163,7 +165,21 @@ class Slate(): job_id = job.submit() print(f'[i] Job submitted to slurm with id {job_id}') - def run_single(self, config): + def run_parallel(self, config): + schedC = self.consume(config, 'scheduler') + repetitions = self.consume(schedC, 'repetitions') + agents_per_job = self.consume(schedC, 'agents_per_job') + reps_per_agent = self.consume(schedC, 'reps_per_agent') + assert schedC == {} + + num_p = min(agents_per_job, math.ceil(repetitions / reps_per_agent)) + + if num_p == 1: + return [self.run_single(config, max_reps=reps_per_agent)] + + return Pool(processes=num_p).map(partial(self.run_single, config, max_reps=reps_per_agent), range(num_p)) + + def run_single(self, config, max_reps=-1, p_ind=0): runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {}) runner = self.runners[runnerName]