diff --git a/slate/slate.py b/slate/slate.py index 798c884..d1a4ee3 100644 --- a/slate/slate.py +++ b/slate/slate.py @@ -39,6 +39,8 @@ class Slate(): def __init__(self, runners): self.runners = runners self.runners['printConfig'] = print_config_runner + self.runners['pdb'] = pdb_runner + self._version = False def load_config(self, filename, name): config, stack = self._load_config(filename, name) @@ -140,6 +142,13 @@ class Slate(): child_keys = '.'.join(keys_arr[1:]) return self.consume(child, child_keys, default=default, expand=expand, **kwargs) + def get_version(self): + if not self._version: + repo = git.Repo(search_parent_directories=True) + sha = repo.head.object.hexsha + self._version = sha + return self._version + def _calc_num_jobs(self, schedC): schedulerC = copy.deepcopy(schedC) reps = self.consume(schedulerC, 'repetitions', 1) @@ -160,13 +169,13 @@ class Slate(): reps_for_job[i % num_jobs].append(i) return reps_for_job[job_id-1] - def run_local(self, filename, name, sha, job_id): + def run_local(self, filename, name, job_id): config = self.load_config(filename, name) schedulerC = copy.deepcopy(config.get('scheduler', {})) rep_ids = self._reps_for_job(schedulerC, job_id) self._fork_processes(config, rep_ids) - def run_slurm(self, filename, name, sha): + def run_slurm(self, filename, name): assert slurm_avaible, 'pyslurm does not seem to be installed on this system.' config = self.load_config(filename, name) slurmC = self.consume(config, 'slurm', expand=True) @@ -190,7 +199,7 @@ class Slate(): job_id = job.submit() print(f'[>] Job submitted to slurm with id {job_id}') with open('job_hist.log', 'a') as f: - f.write(f'{filename}:{name} submitted to slurm with ids {job_id}_0-{num_jobs} on [git:{sha}] at {datetime.datetime.now()}\n') + f.write(f'{filename}:{name} submitted to slurm with ids {job_id}_0-{num_jobs} on [git:{self.get_version()}] at {datetime.datetime.now()}\n') def _fork_processes(self, config, rep_ids): schedC = self.consume(config, 'scheduler') @@ -299,20 +308,17 @@ class Slate(): args = parser.parse_args() - repo = git.Repo(search_parent_directories=True) - sha = repo.head.object.hexsha - print(f'[i] I have job_id {args.job_id}') - print(f'[i] Running on version [git:{sha}]') + print(f'[i] Running on version [git:{self.get_version()}]') if args.worker: raise Exception('Not yet implemented') assert args.config_file != None, 'Need to supply config file.' if args.slurm: - self.run_slurm(args.config_file, args.experiment, sha) + self.run_slurm(args.config_file, args.experiment) else: - self.run_local(args.config_file, args.experiment, sha, args.job_id) + self.run_local(args.config_file, args.experiment, args.job_id) def print_config_runner(slate, run, config): @@ -325,5 +331,9 @@ def print_config_runner(slate, run, config): del config[k] +def pdb_runner(slate, run, config): + d() + + if __name__ == '__main__': raise Exception('You are using it wrong...')