Multiple Processes for Slurm
parent be76a5363a
commit 5e8d6d2552

slate/slate.py | 164
@@ -6,7 +6,7 @@ import random
 import copy
 import collections.abc
 from functools import partial
-from multiprocessing import Pool
+from multiprocessing import Process
 
 import pdb
 d = pdb.set_trace
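
The import swap in the hunk above is the heart of this commit: repetitions are no longer mapped over a multiprocessing.Pool but forked as explicitly managed multiprocessing.Process workers that are started and joined by hand (see _fork_processes below). A minimal sketch of that fork/join pattern, with illustrative names that are not from the repo:

    from functools import partial
    from multiprocessing import Process

    def work(config, rep_ids, p_ind):
        print(f'[P{p_ind}] running reps {rep_ids}')

    if __name__ == '__main__':
        procs = []
        for p, chunk in enumerate([[0, 2], [1, 3]]):
            proc = Process(target=partial(work, {}, rep_ids=chunk, p_ind=p))
            proc.start()            # fork one agent per chunk of reps
            procs.append(proc)
        for proc in procs:
            proc.join()             # block until every agent finishes
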
@@ -114,22 +114,116 @@ class Slate():
         child_keys = '.'.join(keys_arr[1:])
         return self.consume(child, child_keys, default=default)
 
-    def run_local(self, filename, name, job_num=None):
+    def _calc_num_jobs(self, schedulerC):
+        reps = schedulerC.get('repetitions', 1)
+        agents_per_job = schedulerC.get('agents_per_job', 1)
+        reps_per_agent = schedulerC.get('reps_per_agent', 1)
+        reps_per_job = reps_per_agent * agents_per_job
+        jobs_needed = math.ceil(reps / reps_per_job)
+        return jobs_needed
+
+    def _reps_for_job(self, schedulerC, job_id):
+        reps = schedulerC.get('repetitions', 1)
+        num_jobs = self._calc_num_jobs(schedulerC)
+        reps_for_job = [[] for _ in range(num_jobs)]
+        for i in range(reps):
+            reps_for_job[i % num_jobs].append(i)
+        return reps_for_job[job_id]
+
+    def run_local(self, filename, name, job_id=0):
         config = self.load_config(filename, name)
+        schedulerC = copy.deepcopy(config.get('scheduler', {}))
+        rep_ids = self._reps_for_job(schedulerC, job_id)
+        self._fork_processes(config, rep_ids)
+
+    def run_slurm(self, filename, name):
+        assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
+        config = self.load_config(filename, name)
+        slurmC = self.consume(config, 'slurm')
+        schedC = self.consume(config, 'scheduler')
+        s_name = self.consume(slurmC, 'name')
+
+        python_script = 'main.py'
+        sh_lines = ['#!/bin/bash']
+        sh_lines += self.consume(slurmC, 'sh_lines', [])
+        if venv := self.consume(slurmC, 'venv', False):
+            sh_lines += [f'source activate {venv}']
+        sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
+        script = "\n".join(sh_lines)
+
+        num_jobs = self._calc_num_jobs(schedC)
+
+        last_job_idx = num_jobs - 1
+        num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
+        array = f'0-{last_job_idx}%{num_parallel_jobs}'
+        job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
+        job_id = job.submit()
+        print(f'[i] Job submitted to slurm with id {job_id}')
+
+    def _fork_processes(self, config, rep_ids):
+        schedC = self.consume(config, 'scheduler')
+        agents_per_job = self.consume(schedC, 'agents_per_job', 1)
+        reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
+
+        node_reps = len(rep_ids)
+        num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))
+
+        if num_p == 1:
+            self._run_single(config, rep_ids=rep_ids, p_ind=0)
+            return
+
+        procs = []
+
+        reps_done = 0
+
+        for p in range(num_p):
+            num_reps = min(node_reps - reps_done, reps_per_agent)
+            proc_rep_ids = [rep_ids[i] for i in range(reps_done, reps_done + num_reps)]
+            proc = Process(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
+            proc.start()
+            procs.append(proc)
+            reps_done += num_reps
+
+        for proc in procs:
+            proc.join()
+
+    def _run_process(self, orig_config, rep_ids, p_ind):
+        config = copy.deepcopy(orig_config)
         if self.consume(config, 'sweep.enable', False):
             sweepC = self.consume(config, 'sweep')
-            project = self.consume(config, 'wandb.project')
+            project = config['wandb']['project']
             sweep_id = wandb.sweep(
                 sweep=sweepC,
                 project=project
             )
-            runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {})
-            wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, runnerName, project, wandbC), count=config['run']['reps_per_agent'])
+            wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
         else:
             self.consume(config, 'sweep', {})
-            self.run_single(config)
+            self._run_single(config, rep_ids, p_ind=p_ind)
 
+    def _run_single(self, orig_config, rep_ids, p_ind):
+        print(f'[P{p_ind}] I will work on reps {rep_ids}')
+        runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {})
+        project = self.consume(wandbC, 'project')
+
+        runner = self.runners[runnerName]
+
+        for r in rep_ids:
+            config = copy.deepcopy(orig_config)
+            with wandb.init(
+                project=project,
+                config=config,
+                **wandbC
+            ) as run:
+                runner(self, run, config)
+
+            assert config == {}, ('Config was not completely consumed: ', config)
+        orig_config = config
+
-    def _run_from_sweep(self, orig_config, runnerName, project, wandbC):
+    def _run_from_sweep(self, orig_config, p_ind):
+        runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {})
+        project = self.consume(wandbC, 'project')
+
         runner = self.runners[runnerName]
 
         with wandb.init(
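
A worked example of the scheduling arithmetic above, under assumed scheduler values that are not from the repo: with repetitions=10, agents_per_job=2 and reps_per_agent=2, one job covers 2 * 2 = 4 reps, so _calc_num_jobs returns ceil(10 / 4) = 3, and _reps_for_job deals the rep indices out round-robin, one independent bucket per job:

    import math

    def calc_num_jobs(reps, agents_per_job, reps_per_agent):
        return math.ceil(reps / (reps_per_agent * agents_per_job))

    def reps_for_job(reps, num_jobs, job_id):
        buckets = [[] for _ in range(num_jobs)]  # one independent list per job
        for i in range(reps):
            buckets[i % num_jobs].append(i)      # deal reps round-robin
        return buckets[job_id]

    num_jobs = calc_num_jobs(10, 2, 2)           # ceil(10 / 4) == 3
    print([reps_for_job(10, num_jobs, j) for j in range(num_jobs)])
    # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
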
@@ -141,57 +235,7 @@ class Slate():
             runner(self, run, config)
 
             assert config == {}, ('Config was not completely consumed: ', config)
 
-    def run_slurm(self, filename, name):
-        assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
-        config = self.load_config(filename, name)
-        slurmC = self.consume(config, 'slurm')
-        s_name = self.consume(slurmC, 'name')
-
-        python_script = 'main.py'
-        sh_lines = ['#!/bin/bash']
-        sh_lines += self.consume(slurmC, 'sh_lines', [])
-        if venv := self.consume(slurmC, 'venv', False):
-            sh_lines += [f'source activate {venv}']
-        sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
-        script = "\n".join(sh_lines)
-
-        num_jobs = 1
-
-        last_job_idx = num_jobs - 1
-        num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
-        array = f'0-{last_job_idx}%{num_parallel_jobs}'
-        job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
-        job_id = job.submit()
-        print(f'[i] Job submitted to slurm with id {job_id}')
-
-    def run_parallel(self, config):
-        schedC = self.consume(config, 'scheduler')
-        repetitions = self.consume(schedC, 'repetitions')
-        agents_per_job = self.consume(schedC, 'agents_per_job')
-        reps_per_agent = self.consume(schedC, 'reps_per_agent')
-        assert schedC == {}
-
-        num_p = min(agents_per_job, math.ceil(repetitions / reps_per_agent))
-
-        if num_p == 1:
-            return [self.run_single(config, max_reps=reps_per_agent)]
-
-        return Pool(processes=num_p).map(partial(self.run_single, config, max_reps=reps_per_agent), range(num_p))
-
-    def run_single(self, config, max_reps=-1, p_ind=0):
-        runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {})
-
-        runner = self.runners[runnerName]
-
-        with wandb.init(
-            project=self.consume(wandbC, 'project'),
-            config=config,
-            **wandbC
-        ) as run:
-            runner(self, run, config)
-
-        assert config == {}, ('Config was not completely consumed: ', config)
-        orig_config = config
 
     def from_args(self):
         import argparse
@@ -201,7 +245,7 @@ class Slate():
         parser.add_argument("experiment", nargs='?', default='DEFAULT')
         parser.add_argument("-s", "--slurm", action="store_true")
         parser.add_argument("-w", "--worker", action="store_true")
-        parser.add_argument("-j", "--job_num", default=None)
+        parser.add_argument("-j", "--job_id", default=0, type=int)
 
         args = parser.parse_args()
 
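
Each Slurm array task re-invokes the entry point with its own index (-j $SLURM_ARRAY_TASK_ID), and the same flag selects a single job's slice for a local run; argparse hands flag values over as strings, hence the type=int coercion before the index is used to pick a rep bucket. A hypothetical invocation, with placeholder config path and experiment name:

    python3 main.py configs/example.yaml DEFAULT -j 2   # run only the reps assigned to job 2
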
@@ -212,7 +256,7 @@ class Slate():
         if args.slurm:
             self.run_slurm(args.config_file, args.experiment)
         else:
-            self.run_local(args.config_file, args.experiment, args.job_num)
+            self.run_local(args.config_file, args.experiment, args.job_id)
 
 
 def print_config_runner(slate, run, config):
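
The array spec that run_slurm builds uses Slurm's throttle syntax: '0-<last>%<limit>' enumerates the task IDs while capping how many may run at once. A sketch of just that string construction, under assumed values (12 jobs, at most 4 in parallel) that are not from the repo:

    num_jobs = 12
    num_parallel_jobs = min(4, num_jobs)
    array = f'0-{num_jobs - 1}%{num_parallel_jobs}'
    print(array)  # '0-11%4': tasks 0 through 11, at most 4 running concurrently
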