Support threading
This commit is contained in:
parent
756c096f56
commit
ebe4e04861
@ -7,12 +7,15 @@ import copy
|
|||||||
import collections.abc
|
import collections.abc
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from multiprocessing import Process
|
from multiprocessing import Process
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
import pdb
|
import pdb
|
||||||
d = pdb.set_trace
|
d = pdb.set_trace
|
||||||
|
|
||||||
REQUIRE_CONFIG_CONSUMED = False
|
REQUIRE_CONFIG_CONSUMED = False
|
||||||
|
|
||||||
|
Parallelization_Primitive = Thread # Process
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pyslurm
|
import pyslurm
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -36,7 +39,7 @@ class Slate():
|
|||||||
def load_config(self, filename, name):
|
def load_config(self, filename, name):
|
||||||
config, stack = self._load_config(filename, name)
|
config, stack = self._load_config(filename, name)
|
||||||
print('[i] Merged Configs: ', stack)
|
print('[i] Merged Configs: ', stack)
|
||||||
self.deep_expand_vars(config, config=config)
|
self._config = copy.deepcopy(config)
|
||||||
self.consume(config, 'vars', {})
|
self.consume(config, 'vars', {})
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@ -87,7 +90,7 @@ class Slate():
|
|||||||
if isinstance(string, str):
|
if isinstance(string, str):
|
||||||
if string == '{rand}':
|
if string == '{rand}':
|
||||||
return int(random.random()*99999999)
|
return int(random.random()*99999999)
|
||||||
return string.format(**kwargs, rand=int(random.random()*99999999))
|
return string.format(**kwargs, rand=int(random.random()*99999999), srand=srand)
|
||||||
return string
|
return string
|
||||||
|
|
||||||
def apply_nested(self, d, f):
|
def apply_nested(self, d, f):
|
||||||
@ -103,7 +106,7 @@ class Slate():
|
|||||||
def deep_expand_vars(self, dict, **kwargs):
|
def deep_expand_vars(self, dict, **kwargs):
|
||||||
self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))
|
self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))
|
||||||
|
|
||||||
def consume(self, conf, key, default=None):
|
def consume(self, conf, key, default=None, **kwargs):
|
||||||
keys_arr = key.split('.')
|
keys_arr = key.split('.')
|
||||||
if len(keys_arr) == 1:
|
if len(keys_arr) == 1:
|
||||||
k = keys_arr[0]
|
k = keys_arr[0]
|
||||||
@ -113,10 +116,14 @@ class Slate():
|
|||||||
val = conf[k]
|
val = conf[k]
|
||||||
if k in conf:
|
if k in conf:
|
||||||
del conf[k]
|
del conf[k]
|
||||||
|
|
||||||
|
while val.find('{') != -1:
|
||||||
|
val = self.expand_vars(val, config=self._config, **kwargs)
|
||||||
|
|
||||||
return val
|
return val
|
||||||
child = conf.get(keys_arr[0], {})
|
child = conf.get(keys_arr[0], {})
|
||||||
child_keys = '.'.join(keys_arr[1:])
|
child_keys = '.'.join(keys_arr[1:])
|
||||||
return self.consume(child, child_keys, default=default)
|
return self.consume(child, child_keys, default=default, **kwargs)
|
||||||
|
|
||||||
def _calc_num_jobs(self, schedulerC):
|
def _calc_num_jobs(self, schedulerC):
|
||||||
reps = schedulerC.get('repetitions', 1)
|
reps = schedulerC.get('repetitions', 1)
|
||||||
@ -185,7 +192,7 @@ class Slate():
|
|||||||
for p in range(num_p):
|
for p in range(num_p):
|
||||||
num_reps = min(node_reps - reps_done, reps_per_agent)
|
num_reps = min(node_reps - reps_done, reps_per_agent)
|
||||||
proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done+num_reps))]
|
proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done+num_reps))]
|
||||||
proc = Process(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
|
proc = Parallelization_Primitive(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
|
||||||
proc.start()
|
proc.start()
|
||||||
procs.append(proc)
|
procs.append(proc)
|
||||||
reps_done += num_reps
|
reps_done += num_reps
|
||||||
|
Loading…
Reference in New Issue
Block a user