Seperate job_id and task_id and allow access from configs

This commit is contained in:
Dominik Moritz Roth 2023-07-31 15:34:10 +02:00
parent a16f3889f6
commit 6f9fbe6b28

View File

@ -44,6 +44,9 @@ class Slate():
} }
self.runners.update(runners) self.runners.update(runners)
self._version = False self._version = False
self.job_id = os.environ.get('SLURM_JOB_ID', False)
self.task_id = None
self._tmp_path = os.path.expandvars('$TMP')
self.sweep_id = None self.sweep_id = None
def load_config(self, filename, name): def load_config(self, filename, name):
@ -104,10 +107,9 @@ class Slate():
def expand_vars(self, string, delta_desc='BASE', **kwargs): def expand_vars(self, string, delta_desc='BASE', **kwargs):
if isinstance(string, str): if isinstance(string, str):
rand = int(random.random()*99999999) rand = int(random.random()*99999999)
tmp = os.path.expandvars('$TMP')
if string == '{rand}': if string == '{rand}':
return rand return rand
return string.format(delta_desc=delta_desc, **kwargs, rand=rand, tmp=tmp) return string.format(delta_desc=delta_desc, **kwargs, rand=rand, tmp=self._tmp_path, job_id=(self.job_id or 'LOCAL'), task_id=(self.task_id or 0))
return string return string
def apply_nested(self, d, f): def apply_nested(self, d, f):
@ -175,22 +177,23 @@ class Slate():
jobs_needed = math.ceil(reps / reps_per_job) jobs_needed = math.ceil(reps / reps_per_job)
return jobs_needed return jobs_needed
def _reps_for_job(self, schedC, job_id, num_conv_versions): def _reps_for_job(self, schedC, task_id, num_conv_versions):
schedulerC = copy.deepcopy(schedC) schedulerC = copy.deepcopy(schedC)
num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions) num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions)
reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions) reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions)
if job_id == None: if task_id == None:
return list(range(0, reps)) return list(range(0, reps))
reps_for_job = [[] for i in range(num_jobs)] reps_for_job = [[] for i in range(num_jobs)]
for i in range(reps): for i in range(reps):
reps_for_job[i % num_jobs].append(i) reps_for_job[i % num_jobs].append(i)
return reps_for_job[job_id] return reps_for_job[task_id]
def run_local(self, filename, name, job_id, sweep_id): def run_local(self, filename, name, task_id, sweep_id):
self.task_id = task_id
config = self.load_config(filename, name) config = self.load_config(filename, name)
num_conv_versions = self._get_num_conv_versions(config) num_conv_versions = self._get_num_conv_versions(config)
schedulerC = copy.deepcopy(config.get('scheduler', {})) schedulerC = copy.deepcopy(config.get('scheduler', {}))
rep_ids = self._reps_for_job(schedulerC, job_id, num_conv_versions) rep_ids = self._reps_for_job(schedulerC, task_id, num_conv_versions)
self.sweep_id = sweep_id self.sweep_id = sweep_id
self._init_sweep(config) self._init_sweep(config)
self._fork_processes(config, rep_ids) self._fork_processes(config, rep_ids)
@ -219,7 +222,7 @@ class Slate():
sh_lines += self.consume(slurmC, 'sh_lines', []) sh_lines += self.consume(slurmC, 'sh_lines', [])
if venv := self.consume(slurmC, 'venv', False): if venv := self.consume(slurmC, 'venv', False):
sh_lines += [f'source activate {venv}'] sh_lines += [f'source activate {venv}']
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID --sweep_id {self.sweep_id}'] sh_lines += [f'python3 {python_script} {filename} {name} -t $SLURM_ARRAY_TASK_ID --sweet_id {self.sweep_id}']
script = "\n".join(sh_lines) script = "\n".join(sh_lines)
num_jobs = self._calc_num_jobs(schedC, num_conv_versions) num_jobs = self._calc_num_jobs(schedC, num_conv_versions)
@ -380,22 +383,23 @@ class Slate():
parser.add_argument("experiment", nargs='?', default='DEFAULT') parser.add_argument("experiment", nargs='?', default='DEFAULT')
parser.add_argument("-s", "--slurm", action="store_true") parser.add_argument("-s", "--slurm", action="store_true")
parser.add_argument("-w", "--worker", action="store_true") parser.add_argument("-w", "--worker", action="store_true")
parser.add_argument("-j", "--job_id", default=None, type=int) parser.add_argument("-t", "--task_id", default=None, type=int)
parser.add_argument("--sweep_id", default=None, type=str) parser.add_argument("--sweep_id", default=None, type=str)
args = parser.parse_args() args = parser.parse_args()
print(f'[i] I have job_id {args.job_id}') print(f'[i] I have task_id {args.task_id}')
print(f'[i] Running on version [git:{self.get_version()}]') print(f'[i] Running on version [git:{self.get_version()}]')
if args.worker: if args.worker:
raise Exception('Not yet implemented') raise Exception('Not yet implemented')
assert args.config_file != None, 'Need to supply config file.' assert args.config_file != None, 'Need to supply config file.'
if args.slurm: if args.slurm:
self.run_slurm(args.config_file, args.experiment) self.run_slurm(args.config_file, args.experiment)
else: else:
self.run_local(args.config_file, args.experiment, args.job_id, args.sweep_id) self.run_local(args.config_file, args.experiment, args.task_id, args.sweep_id)
def params_combine(config: dict, key: str, iter_func): def params_combine(config: dict, key: str, iter_func):