Patch summary (free text before the first "diff --git" header):
  * The per-array-task index is renamed from job_id to task_id: _reps_for_job() and run_local() take task_id, and the CLI flag -j/--job_id becomes -t/--task_id.
  * Slate.__init__ now stores SLURM_JOB_ID from the environment as self.job_id (False when unset), initialises self.task_id to None, and expands $TMP once into self._tmp_path.
  * expand_vars() exposes tmp, job_id (falls back to 'LOCAL') and task_id (falls back to 0) to format strings.
Review fixes applied to the added lines of this patch:
  * The generated worker command passes --sweep_id (it was misspelled --sweet_id, an option the argparse parser in this same patch does not define).
  * The new task_id check uses "is None" instead of "== None".
NOTE(review): line breaks and indentation of the hunks below were already collapsed in the reviewed copy and are left untouched; restore them from the original patch before applying.
diff --git a/slate/slate.py b/slate/slate.py index bca09b8..b5ef5ec 100644 --- a/slate/slate.py +++ b/slate/slate.py @@ -44,6 +44,9 @@ class Slate(): } self.runners.update(runners) self._version = False + self.job_id = os.environ.get('SLURM_JOB_ID', False) + self.task_id = None + self._tmp_path = os.path.expandvars('$TMP') self.sweep_id = None def load_config(self, filename, name): @@ -104,10 +107,9 @@ class Slate(): def expand_vars(self, string, delta_desc='BASE', **kwargs): if isinstance(string, str): rand = int(random.random()*99999999) - tmp = os.path.expandvars('$TMP') if string == '{rand}': return rand - return string.format(delta_desc=delta_desc, **kwargs, rand=rand, tmp=tmp) + return string.format(delta_desc=delta_desc, **kwargs, rand=rand, tmp=self._tmp_path, job_id=(self.job_id or 'LOCAL'), task_id=(self.task_id or 0)) return string def apply_nested(self, d, f): @@ -175,22 +177,23 @@ class Slate(): jobs_needed = math.ceil(reps / reps_per_job) return jobs_needed - def _reps_for_job(self, schedC, job_id, num_conv_versions): + def _reps_for_job(self, schedC, task_id, num_conv_versions): schedulerC = copy.deepcopy(schedC) num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions) reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions) - if job_id == None: + if task_id is None: return list(range(0, reps)) reps_for_job = [[] for i in range(num_jobs)] for i in range(reps): reps_for_job[i % num_jobs].append(i) - return reps_for_job[job_id] + return reps_for_job[task_id] - def run_local(self, filename, name, job_id, sweep_id): + def run_local(self, filename, name, task_id, sweep_id): + self.task_id = task_id config = self.load_config(filename, name) num_conv_versions = self._get_num_conv_versions(config) schedulerC = copy.deepcopy(config.get('scheduler', {})) - rep_ids = self._reps_for_job(schedulerC, job_id, num_conv_versions) + rep_ids = self._reps_for_job(schedulerC, task_id, num_conv_versions) 
self.sweep_id = sweep_id self._init_sweep(config) self._fork_processes(config, rep_ids) @@ -219,7 +222,7 @@ class Slate(): sh_lines += self.consume(slurmC, 'sh_lines', []) if venv := self.consume(slurmC, 'venv', False): sh_lines += [f'source activate {venv}'] - sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID --sweep_id {self.sweep_id}'] + sh_lines += [f'python3 {python_script} {filename} {name} -t $SLURM_ARRAY_TASK_ID --sweep_id {self.sweep_id}'] script = "\n".join(sh_lines) num_jobs = self._calc_num_jobs(schedC, num_conv_versions) @@ -380,22 +383,23 @@ class Slate(): parser.add_argument("experiment", nargs='?', default='DEFAULT') parser.add_argument("-s", "--slurm", action="store_true") parser.add_argument("-w", "--worker", action="store_true") - parser.add_argument("-j", "--job_id", default=None, type=int) + parser.add_argument("-t", "--task_id", default=None, type=int) parser.add_argument("--sweep_id", default=None, type=str) args = parser.parse_args() - print(f'[i] I have job_id {args.job_id}') + print(f'[i] I have task_id {args.task_id}') print(f'[i] Running on version [git:{self.get_version()}]') if args.worker: raise Exception('Not yet implemented') assert args.config_file != None, 'Need to supply config file.' + if args.slurm: self.run_slurm(args.config_file, args.experiment) else: - self.run_local(args.config_file, args.experiment, args.job_id, args.sweep_id) + self.run_local(args.config_file, args.experiment, args.task_id, args.sweep_id) def params_combine(config: dict, key: str, iter_func):