Support threading
This commit is contained in:
		
							parent
							
								
									756c096f56
								
							
						
					
					
						commit
						ebe4e04861
					
				| @ -7,12 +7,15 @@ import copy | |||||||
| import collections.abc | import collections.abc | ||||||
| from functools import partial | from functools import partial | ||||||
| from multiprocessing import Process | from multiprocessing import Process | ||||||
|  | from threading import Thread | ||||||
| 
 | 
 | ||||||
| import pdb | import pdb | ||||||
| d = pdb.set_trace | d = pdb.set_trace | ||||||
| 
 | 
 | ||||||
| REQUIRE_CONFIG_CONSUMED = False | REQUIRE_CONFIG_CONSUMED = False | ||||||
| 
 | 
 | ||||||
|  | Parallelization_Primitive = Thread  # Process | ||||||
|  | 
 | ||||||
| try: | try: | ||||||
|     import pyslurm |     import pyslurm | ||||||
| except ImportError: | except ImportError: | ||||||
| @ -36,7 +39,7 @@ class Slate(): | |||||||
|     def load_config(self, filename, name): |     def load_config(self, filename, name): | ||||||
|         config, stack = self._load_config(filename, name) |         config, stack = self._load_config(filename, name) | ||||||
|         print('[i] Merged Configs: ', stack) |         print('[i] Merged Configs: ', stack) | ||||||
|         self.deep_expand_vars(config, config=config) |         self._config = copy.deepcopy(config) | ||||||
|         self.consume(config, 'vars', {}) |         self.consume(config, 'vars', {}) | ||||||
|         return config |         return config | ||||||
| 
 | 
 | ||||||
| @ -87,7 +90,7 @@ class Slate(): | |||||||
|         if isinstance(string, str): |         if isinstance(string, str): | ||||||
|             if string == '{rand}': |             if string == '{rand}': | ||||||
|                 return int(random.random()*99999999) |                 return int(random.random()*99999999) | ||||||
|             return string.format(**kwargs, rand=int(random.random()*99999999)) |             return string.format(**kwargs, rand=int(random.random()*99999999), srand=srand) | ||||||
|         return string |         return string | ||||||
| 
 | 
 | ||||||
|     def apply_nested(self, d, f): |     def apply_nested(self, d, f): | ||||||
| @ -103,7 +106,7 @@ class Slate(): | |||||||
|     def deep_expand_vars(self, dict, **kwargs): |     def deep_expand_vars(self, dict, **kwargs): | ||||||
|         self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs)) |         self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs)) | ||||||
| 
 | 
 | ||||||
|     def consume(self, conf, key, default=None): |     def consume(self, conf, key, default=None, **kwargs): | ||||||
|         keys_arr = key.split('.') |         keys_arr = key.split('.') | ||||||
|         if len(keys_arr) == 1: |         if len(keys_arr) == 1: | ||||||
|             k = keys_arr[0] |             k = keys_arr[0] | ||||||
| @ -113,10 +116,14 @@ class Slate(): | |||||||
|                 val = conf[k] |                 val = conf[k] | ||||||
|             if k in conf: |             if k in conf: | ||||||
|                 del conf[k] |                 del conf[k] | ||||||
|  | 
 | ||||||
|  |             while val.find('{') != -1: | ||||||
|  |                 val = self.expand_vars(val, config=self._config, **kwargs) | ||||||
|  | 
 | ||||||
|             return val |             return val | ||||||
|         child = conf.get(keys_arr[0], {}) |         child = conf.get(keys_arr[0], {}) | ||||||
|         child_keys = '.'.join(keys_arr[1:]) |         child_keys = '.'.join(keys_arr[1:]) | ||||||
|         return self.consume(child, child_keys, default=default) |         return self.consume(child, child_keys, default=default, **kwargs) | ||||||
| 
 | 
 | ||||||
|     def _calc_num_jobs(self, schedulerC): |     def _calc_num_jobs(self, schedulerC): | ||||||
|         reps = schedulerC.get('repetitions', 1) |         reps = schedulerC.get('repetitions', 1) | ||||||
| @ -185,7 +192,7 @@ class Slate(): | |||||||
|         for p in range(num_p): |         for p in range(num_p): | ||||||
|             num_reps = min(node_reps - reps_done, reps_per_agent) |             num_reps = min(node_reps - reps_done, reps_per_agent) | ||||||
|             proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done+num_reps))] |             proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done+num_reps))] | ||||||
|             proc = Process(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p)) |             proc = Parallelization_Primitive(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p)) | ||||||
|             proc.start() |             proc.start() | ||||||
|             procs.append(proc) |             procs.append(proc) | ||||||
|             reps_done += num_reps |             reps_done += num_reps | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user