Multiple Processes fro Slurm
This commit is contained in:
		
							parent
							
								
									be76a5363a
								
							
						
					
					
						commit
						5e8d6d2552
					
				
							
								
								
									
										130
									
								
								slate/slate.py
									
									
									
									
									
								
							
							
						
						
									
										130
									
								
								slate/slate.py
									
									
									
									
									
								
							@ -6,7 +6,7 @@ import random
 | 
				
			|||||||
import copy
 | 
					import copy
 | 
				
			||||||
import collections.abc
 | 
					import collections.abc
 | 
				
			||||||
from functools import partial
 | 
					from functools import partial
 | 
				
			||||||
from multiprocessing import Pool
 | 
					from multiprocessing import Process
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import pdb
 | 
					import pdb
 | 
				
			||||||
d = pdb.set_trace
 | 
					d = pdb.set_trace
 | 
				
			||||||
@ -114,38 +114,33 @@ class Slate():
 | 
				
			|||||||
        child_keys = '.'.join(keys_arr[1:])
 | 
					        child_keys = '.'.join(keys_arr[1:])
 | 
				
			||||||
        return self.consume(child, child_keys, default=default)
 | 
					        return self.consume(child, child_keys, default=default)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def run_local(self, filename, name, job_num=None):
 | 
					    def _calc_num_jobs(self, schedulerC):
 | 
				
			||||||
 | 
					        reps = schedulerC.get('repetitions', 1)
 | 
				
			||||||
 | 
					        agents_per_job = schedulerC.get('agents_per_job', 1)
 | 
				
			||||||
 | 
					        reps_per_agent = schedulerC.get('reps_per_agent', 1)
 | 
				
			||||||
 | 
					        reps_per_job = reps_per_agent * agents_per_job
 | 
				
			||||||
 | 
					        jobs_needed = math.ceil(reps / reps_per_job)
 | 
				
			||||||
 | 
					        return jobs_needed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _reps_for_job(self, schedulerC, job_id):
 | 
				
			||||||
 | 
					        reps = schedulerC.get('repetitions', 1)
 | 
				
			||||||
 | 
					        num_jobs = self._calc_num_jobs(schedulerC)
 | 
				
			||||||
 | 
					        reps_for_job = [[]] * num_jobs
 | 
				
			||||||
 | 
					        for i in range(reps):
 | 
				
			||||||
 | 
					            reps_for_job[i % num_jobs].append(i)
 | 
				
			||||||
 | 
					        return reps_for_job[job_id]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def run_local(self, filename, name, job_id=0):
 | 
				
			||||||
        config = self.load_config(filename, name)
 | 
					        config = self.load_config(filename, name)
 | 
				
			||||||
        if self.consume(config, 'sweep.enable', False):
 | 
					        schedulerC = copy.deepcopy(config.get('scheduler', {}))
 | 
				
			||||||
            sweepC = self.consume(config, 'sweep')
 | 
					        rep_ids = self._reps_for_job(schedulerC, job_id)
 | 
				
			||||||
            project = self.consume(config, 'wandb.project')
 | 
					        self._fork_processes(config, rep_ids)
 | 
				
			||||||
            sweep_id = wandb.sweep(
 | 
					 | 
				
			||||||
                sweep=sweepC,
 | 
					 | 
				
			||||||
                project=project
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {})
 | 
					 | 
				
			||||||
            wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, runnerName, project, wandbC), count=config['run']['reps_per_agent'])
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            self.consume(config, 'sweep', {})
 | 
					 | 
				
			||||||
            self.run_single(config)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def _run_from_sweep(self, orig_config, runnerName, project, wandbC):
 | 
					 | 
				
			||||||
        runner = self.runners[runnerName]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        with wandb.init(
 | 
					 | 
				
			||||||
            project=project,
 | 
					 | 
				
			||||||
            **wandbC
 | 
					 | 
				
			||||||
        ) as run:
 | 
					 | 
				
			||||||
            config = copy.deepcopy(orig_config)
 | 
					 | 
				
			||||||
            self.deep_update(config, wandb.config)
 | 
					 | 
				
			||||||
            runner(self, run, config)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        assert config == {}, ('Config was not completely consumed: ', config)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def run_slurm(self, filename, name):
 | 
					    def run_slurm(self, filename, name):
 | 
				
			||||||
        assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
 | 
					        assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
 | 
				
			||||||
        config = self.load_config(filename, name)
 | 
					        config = self.load_config(filename, name)
 | 
				
			||||||
        slurmC = self.consume(config, 'slurm')
 | 
					        slurmC = self.consume(config, 'slurm')
 | 
				
			||||||
 | 
					        schedC = self.consume(config, 'scheduler')
 | 
				
			||||||
        s_name = self.consume(slurmC, 'name')
 | 
					        s_name = self.consume(slurmC, 'name')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        python_script = 'main.py'
 | 
					        python_script = 'main.py'
 | 
				
			||||||
@ -156,7 +151,7 @@ class Slate():
 | 
				
			|||||||
        sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
 | 
					        sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
 | 
				
			||||||
        script = "\n".join(sh_lines)
 | 
					        script = "\n".join(sh_lines)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        num_jobs = 1
 | 
					        num_jobs = self._calc_num_jobs(schedC)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        last_job_idx = num_jobs - 1
 | 
					        last_job_idx = num_jobs - 1
 | 
				
			||||||
        num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
 | 
					        num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
 | 
				
			||||||
@ -165,33 +160,82 @@ class Slate():
 | 
				
			|||||||
        job_id = job.submit()
 | 
					        job_id = job.submit()
 | 
				
			||||||
        print(f'[i] Job submitted to slurm with id {job_id}')
 | 
					        print(f'[i] Job submitted to slurm with id {job_id}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def run_parallel(self, config):
 | 
					    def _fork_processes(self, config, rep_ids):
 | 
				
			||||||
        schedC = self.consume(config, 'scheduler')
 | 
					        schedC = self.consume(config, 'scheduler')
 | 
				
			||||||
        repetitions = self.consume(schedC, 'repetitions')
 | 
					        agents_per_job = self.consume(schedC, 'agents_per_job', 1)
 | 
				
			||||||
        agents_per_job = self.consume(schedC, 'agents_per_job')
 | 
					        reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
 | 
				
			||||||
        reps_per_agent = self.consume(schedC, 'reps_per_agent')
 | 
					 | 
				
			||||||
        assert schedC == {}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        num_p = min(agents_per_job, math.ceil(repetitions / reps_per_agent))
 | 
					        node_reps = len(rep_ids)
 | 
				
			||||||
 | 
					        num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if num_p == 1:
 | 
					        if num_p == 1:
 | 
				
			||||||
            return [self.run_single(config, max_reps=reps_per_agent)]
 | 
					            self._run_single(config, rep_ids=rep_ids, p_ind=0)
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return Pool(processes=num_p).map(partial(self.run_single, config, max_reps=reps_per_agent), range(num_p))
 | 
					        procs = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def run_single(self, config, max_reps=-1, p_ind=0):
 | 
					        reps_done = 0
 | 
				
			||||||
        runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {})
 | 
					
 | 
				
			||||||
 | 
					        for p in range(num_p):
 | 
				
			||||||
 | 
					            num_reps = min(node_reps - reps_done, reps_per_agent)
 | 
				
			||||||
 | 
					            proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done+num_reps))]
 | 
				
			||||||
 | 
					            proc = Process(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
 | 
				
			||||||
 | 
					            proc.start()
 | 
				
			||||||
 | 
					            procs.append(proc)
 | 
				
			||||||
 | 
					            reps_done += num_reps
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for proc in procs:
 | 
				
			||||||
 | 
					            proc.join()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _run_process(self, orig_config, rep_ids, p_ind):
 | 
				
			||||||
 | 
					        config = copy.deepcopy(orig_config)
 | 
				
			||||||
 | 
					        if self.consume(config, 'sweep.enable', False):
 | 
				
			||||||
 | 
					            sweepC = self.consume(config, 'sweep')
 | 
				
			||||||
 | 
					            project = config['wandb']['project']
 | 
				
			||||||
 | 
					            sweep_id = wandb.sweep(
 | 
				
			||||||
 | 
					                sweep=sweepC,
 | 
				
			||||||
 | 
					                project=project
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.consume(config, 'sweep', {})
 | 
				
			||||||
 | 
					            self._run_single(config, rep_ids, p_ind=p_ind)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _run_single(self, orig_config, rep_ids, p_ind):
 | 
				
			||||||
 | 
					        print(f'[P{p_ind}] I will work on reps {rep_ids}')
 | 
				
			||||||
 | 
					        runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {})
 | 
				
			||||||
 | 
					        project = self.consume(wandbC, 'project')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        runner = self.runners[runnerName]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for r in rep_ids:
 | 
				
			||||||
 | 
					            config = copy.deepcopy(orig_config)
 | 
				
			||||||
 | 
					            with wandb.init(
 | 
				
			||||||
 | 
					                project=project,
 | 
				
			||||||
 | 
					                config=config,
 | 
				
			||||||
 | 
					                **wandbC
 | 
				
			||||||
 | 
					            ) as run:
 | 
				
			||||||
 | 
					                runner(self, run, config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            assert config == {}, ('Config was not completely consumed: ', config)
 | 
				
			||||||
 | 
					        orig_config = config
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _run_from_sweep(self, orig_config, p_ind):
 | 
				
			||||||
 | 
					        runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {})
 | 
				
			||||||
 | 
					        project = self.consume(wandbC, 'project')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        runner = self.runners[runnerName]
 | 
					        runner = self.runners[runnerName]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        with wandb.init(
 | 
					        with wandb.init(
 | 
				
			||||||
            project=self.consume(wandbC, 'project'),
 | 
					            project=project,
 | 
				
			||||||
            config=config,
 | 
					 | 
				
			||||||
            **wandbC
 | 
					            **wandbC
 | 
				
			||||||
        ) as run:
 | 
					        ) as run:
 | 
				
			||||||
 | 
					            config = copy.deepcopy(orig_config)
 | 
				
			||||||
 | 
					            self.deep_update(config, wandb.config)
 | 
				
			||||||
            runner(self, run, config)
 | 
					            runner(self, run, config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        assert config == {}, ('Config was not completely consumed: ', config)
 | 
					            assert config == {}, ('Config was not completely consumed: ', config)
 | 
				
			||||||
 | 
					        orig_config = config
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def from_args(self):
 | 
					    def from_args(self):
 | 
				
			||||||
        import argparse
 | 
					        import argparse
 | 
				
			||||||
@ -201,7 +245,7 @@ class Slate():
 | 
				
			|||||||
        parser.add_argument("experiment", nargs='?', default='DEFAULT')
 | 
					        parser.add_argument("experiment", nargs='?', default='DEFAULT')
 | 
				
			||||||
        parser.add_argument("-s", "--slurm", action="store_true")
 | 
					        parser.add_argument("-s", "--slurm", action="store_true")
 | 
				
			||||||
        parser.add_argument("-w", "--worker", action="store_true")
 | 
					        parser.add_argument("-w", "--worker", action="store_true")
 | 
				
			||||||
        parser.add_argument("-j", "--job_num", default=None)
 | 
					        parser.add_argument("-j", "--job_id", default=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        args = parser.parse_args()
 | 
					        args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -212,7 +256,7 @@ class Slate():
 | 
				
			|||||||
        if args.slurm:
 | 
					        if args.slurm:
 | 
				
			||||||
            self.run_slurm(args.config_file, args.experiment)
 | 
					            self.run_slurm(args.config_file, args.experiment)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            self.run_local(args.config_file, args.experiment, args.job_num)
 | 
					            self.run_local(args.config_file, args.experiment, args.job_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def print_config_runner(slate, run, config):
 | 
					def print_config_runner(slate, run, config):
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user