Support threading

This commit is contained in:
Dominik Moritz Roth 2023-07-09 16:12:38 +02:00
parent 756c096f56
commit ebe4e04861

View File

@ -7,12 +7,15 @@ import copy
import collections.abc
from functools import partial
from multiprocessing import Process
from threading import Thread
import pdb
d = pdb.set_trace
REQUIRE_CONFIG_CONSUMED = False
Parallelization_Primitive = Thread # Process
try:
import pyslurm
except ImportError:
@ -36,7 +39,7 @@ class Slate():
def load_config(self, filename, name):
config, stack = self._load_config(filename, name)
print('[i] Merged Configs: ', stack)
self.deep_expand_vars(config, config=config)
self._config = copy.deepcopy(config)
self.consume(config, 'vars', {})
return config
@ -87,7 +90,7 @@ class Slate():
if isinstance(string, str):
if string == '{rand}':
return int(random.random()*99999999)
return string.format(**kwargs, rand=int(random.random()*99999999))
return string.format(**kwargs, rand=int(random.random()*99999999), srand=srand)
return string
def apply_nested(self, d, f):
@ -103,7 +106,7 @@ class Slate():
def deep_expand_vars(self, dict, **kwargs):
self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))
def consume(self, conf, key, default=None):
def consume(self, conf, key, default=None, **kwargs):
keys_arr = key.split('.')
if len(keys_arr) == 1:
k = keys_arr[0]
@ -113,10 +116,14 @@ class Slate():
val = conf[k]
if k in conf:
del conf[k]
while val.find('{') != -1:
val = self.expand_vars(val, config=self._config, **kwargs)
return val
child = conf.get(keys_arr[0], {})
child_keys = '.'.join(keys_arr[1:])
return self.consume(child, child_keys, default=default)
return self.consume(child, child_keys, default=default, **kwargs)
def _calc_num_jobs(self, schedulerC):
reps = schedulerC.get('repetitions', 1)
@ -185,7 +192,7 @@ class Slate():
for p in range(num_p):
num_reps = min(node_reps - reps_done, reps_per_agent)
proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done+num_reps))]
proc = Process(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
proc = Parallelization_Primitive(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
proc.start()
procs.append(proc)
reps_done += num_reps