Support threading
This commit is contained in:
parent
756c096f56
commit
ebe4e04861
@ -7,12 +7,15 @@ import copy
|
||||
import collections.abc
|
||||
from functools import partial
|
||||
from multiprocessing import Process
|
||||
from threading import Thread
|
||||
|
||||
import pdb
|
||||
d = pdb.set_trace
|
||||
|
||||
REQUIRE_CONFIG_CONSUMED = False
|
||||
|
||||
Parallelization_Primitive = Thread # Process
|
||||
|
||||
try:
|
||||
import pyslurm
|
||||
except ImportError:
|
||||
@ -36,7 +39,7 @@ class Slate():
|
||||
def load_config(self, filename, name):
|
||||
config, stack = self._load_config(filename, name)
|
||||
print('[i] Merged Configs: ', stack)
|
||||
self.deep_expand_vars(config, config=config)
|
||||
self._config = copy.deepcopy(config)
|
||||
self.consume(config, 'vars', {})
|
||||
return config
|
||||
|
||||
@ -87,7 +90,7 @@ class Slate():
|
||||
if isinstance(string, str):
|
||||
if string == '{rand}':
|
||||
return int(random.random()*99999999)
|
||||
return string.format(**kwargs, rand=int(random.random()*99999999))
|
||||
return string.format(**kwargs, rand=int(random.random()*99999999), srand=srand)
|
||||
return string
|
||||
|
||||
def apply_nested(self, d, f):
|
||||
@ -103,7 +106,7 @@ class Slate():
|
||||
def deep_expand_vars(self, dict, **kwargs):
|
||||
self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))
|
||||
|
||||
def consume(self, conf, key, default=None):
|
||||
def consume(self, conf, key, default=None, **kwargs):
|
||||
keys_arr = key.split('.')
|
||||
if len(keys_arr) == 1:
|
||||
k = keys_arr[0]
|
||||
@ -113,10 +116,14 @@ class Slate():
|
||||
val = conf[k]
|
||||
if k in conf:
|
||||
del conf[k]
|
||||
|
||||
while val.find('{') != -1:
|
||||
val = self.expand_vars(val, config=self._config, **kwargs)
|
||||
|
||||
return val
|
||||
child = conf.get(keys_arr[0], {})
|
||||
child_keys = '.'.join(keys_arr[1:])
|
||||
return self.consume(child, child_keys, default=default)
|
||||
return self.consume(child, child_keys, default=default, **kwargs)
|
||||
|
||||
def _calc_num_jobs(self, schedulerC):
|
||||
reps = schedulerC.get('repetitions', 1)
|
||||
@ -185,7 +192,7 @@ class Slate():
|
||||
for p in range(num_p):
|
||||
num_reps = min(node_reps - reps_done, reps_per_agent)
|
||||
proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done+num_reps))]
|
||||
proc = Process(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
|
||||
proc = Parallelization_Primitive(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
|
||||
proc.start()
|
||||
procs.append(proc)
|
||||
reps_done += num_reps
|
||||
|
Loading…
Reference in New Issue
Block a user