From ebe4e04861941315e494a1edf9760e9abd5203e5 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 9 Jul 2023 16:12:38 +0200 Subject: [PATCH] Support threading --- slate/slate.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/slate/slate.py b/slate/slate.py index 3d10a8a..1555dea 100644 --- a/slate/slate.py +++ b/slate/slate.py @@ -7,12 +7,15 @@ import copy import collections.abc from functools import partial from multiprocessing import Process +from threading import Thread import pdb d = pdb.set_trace REQUIRE_CONFIG_CONSUMED = False +Parallelization_Primitive = Thread # Process + try: import pyslurm except ImportError: @@ -36,7 +39,7 @@ class Slate(): def load_config(self, filename, name): config, stack = self._load_config(filename, name) print('[i] Merged Configs: ', stack) - self.deep_expand_vars(config, config=config) + self._config = copy.deepcopy(config) self.consume(config, 'vars', {}) return config @@ -87,7 +90,7 @@ class Slate(): if isinstance(string, str): if string == '{rand}': return int(random.random()*99999999) - return string.format(**kwargs, rand=int(random.random()*99999999)) + return string.format(**kwargs, rand=int(random.random()*99999999), srand=srand) return string def apply_nested(self, d, f): @@ -103,7 +106,7 @@ class Slate(): def deep_expand_vars(self, dict, **kwargs): self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs)) - def consume(self, conf, key, default=None): + def consume(self, conf, key, default=None, **kwargs): keys_arr = key.split('.') if len(keys_arr) == 1: k = keys_arr[0] @@ -113,10 +116,14 @@ class Slate(): val = conf[k] if k in conf: del conf[k] + + while val.find('{') != -1: + val = self.expand_vars(val, config=self._config, **kwargs) + return val child = conf.get(keys_arr[0], {}) child_keys = '.'.join(keys_arr[1:]) - return self.consume(child, child_keys, default=default) + return self.consume(child, child_keys, default=default, **kwargs) def _calc_num_jobs(self, schedulerC): reps = schedulerC.get('repetitions', 1) @@ -185,7 +192,7 @@ class Slate(): for p in range(num_p): num_reps = min(node_reps - reps_done, reps_per_agent) proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done+num_reps))] - proc = Process(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p)) + proc = Parallelization_Primitive(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p)) proc.start() procs.append(proc) reps_done += num_reps