Implemented Process Pool

This commit is contained in:
Dominik Moritz Roth 2023-07-07 13:10:06 +02:00
parent a8374d864c
commit be76a5363a

View File

@ -1,10 +1,12 @@
import wandb import wandb
import yaml import yaml
import os import os
import math
import random import random
import copy import copy
import collections.abc import collections.abc
from functools import partial from functools import partial
from multiprocessing import Pool
import pdb import pdb
d = pdb.set_trace d = pdb.set_trace
@ -163,7 +165,21 @@ class Slate():
job_id = job.submit() job_id = job.submit()
print(f'[i] Job submitted to slurm with id {job_id}') print(f'[i] Job submitted to slurm with id {job_id}')
def run_single(self, config): def run_parallel(self, config):
schedC = self.consume(config, 'scheduler')
repetitions = self.consume(schedC, 'repetitions')
agents_per_job = self.consume(schedC, 'agents_per_job')
reps_per_agent = self.consume(schedC, 'reps_per_agent')
assert schedC == {}
num_p = min(agents_per_job, math.ceil(repetitions / reps_per_agent))
if num_p == 1:
return [self.run_single(config, max_reps=reps_per_agent)]
return Pool(processes=num_p).map(partial(self.run_single, config, max_reps=reps_per_agent), range(num_p))
def run_single(self, config, max_reps=-1, p_ind=0):
runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {}) runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {})
runner = self.runners[runnerName] runner = self.runners[runnerName]