Implemented Process Pool

This commit is contained in:
Dominik Moritz Roth 2023-07-07 13:10:06 +02:00
parent a8374d864c
commit be76a5363a

View File

@ -1,10 +1,12 @@
import wandb
import yaml
import os
import math
import random
import copy
import collections.abc
from functools import partial
from multiprocessing import Pool
import pdb
d = pdb.set_trace
@ -163,7 +165,21 @@ class Slate():
job_id = job.submit()
print(f'[i] Job submitted to slurm with id {job_id}')
def run_single(self, config):
def run_parallel(self, config):
schedC = self.consume(config, 'scheduler')
repetitions = self.consume(schedC, 'repetitions')
agents_per_job = self.consume(schedC, 'agents_per_job')
reps_per_agent = self.consume(schedC, 'reps_per_agent')
assert schedC == {}
num_p = min(agents_per_job, math.ceil(repetitions / reps_per_agent))
if num_p == 1:
return [self.run_single(config, max_reps=reps_per_agent)]
return Pool(processes=num_p).map(partial(self.run_single, config, max_reps=reps_per_agent), range(num_p))
def run_single(self, config, max_reps=-1, p_ind=0):
runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {})
runner = self.runners[runnerName]