Seperate job_id and task_id and allow access from configs

This commit is contained in:
Dominik Moritz Roth 2023-07-31 15:34:10 +02:00
parent a16f3889f6
commit 6f9fbe6b28

View File

@ -44,6 +44,9 @@ class Slate():
}
self.runners.update(runners)
self._version = False
self.job_id = os.environ.get('SLURM_JOB_ID', False)
self.task_id = None
self._tmp_path = os.path.expandvars('$TMP')
self.sweep_id = None
def load_config(self, filename, name):
@ -104,10 +107,9 @@ class Slate():
def expand_vars(self, string, delta_desc='BASE', **kwargs):
if isinstance(string, str):
rand = int(random.random()*99999999)
tmp = os.path.expandvars('$TMP')
if string == '{rand}':
return rand
return string.format(delta_desc=delta_desc, **kwargs, rand=rand, tmp=tmp)
return string.format(delta_desc=delta_desc, **kwargs, rand=rand, tmp=self._tmp_path, job_id=(self.job_id or 'LOCAL'), task_id=(self.task_id or 0))
return string
def apply_nested(self, d, f):
@ -175,22 +177,23 @@ class Slate():
jobs_needed = math.ceil(reps / reps_per_job)
return jobs_needed
def _reps_for_job(self, schedC, job_id, num_conv_versions):
def _reps_for_job(self, schedC, task_id, num_conv_versions):
schedulerC = copy.deepcopy(schedC)
num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions)
reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions)
if job_id == None:
if task_id == None:
return list(range(0, reps))
reps_for_job = [[] for i in range(num_jobs)]
for i in range(reps):
reps_for_job[i % num_jobs].append(i)
return reps_for_job[job_id]
return reps_for_job[task_id]
def run_local(self, filename, name, job_id, sweep_id):
def run_local(self, filename, name, task_id, sweep_id):
self.task_id = task_id
config = self.load_config(filename, name)
num_conv_versions = self._get_num_conv_versions(config)
schedulerC = copy.deepcopy(config.get('scheduler', {}))
rep_ids = self._reps_for_job(schedulerC, job_id, num_conv_versions)
rep_ids = self._reps_for_job(schedulerC, task_id, num_conv_versions)
self.sweep_id = sweep_id
self._init_sweep(config)
self._fork_processes(config, rep_ids)
@ -219,7 +222,7 @@ class Slate():
sh_lines += self.consume(slurmC, 'sh_lines', [])
if venv := self.consume(slurmC, 'venv', False):
sh_lines += [f'source activate {venv}']
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID --sweep_id {self.sweep_id}']
sh_lines += [f'python3 {python_script} {filename} {name} -t $SLURM_ARRAY_TASK_ID --sweet_id {self.sweep_id}']
script = "\n".join(sh_lines)
num_jobs = self._calc_num_jobs(schedC, num_conv_versions)
@ -380,22 +383,23 @@ class Slate():
parser.add_argument("experiment", nargs='?', default='DEFAULT')
parser.add_argument("-s", "--slurm", action="store_true")
parser.add_argument("-w", "--worker", action="store_true")
parser.add_argument("-j", "--job_id", default=None, type=int)
parser.add_argument("-t", "--task_id", default=None, type=int)
parser.add_argument("--sweep_id", default=None, type=str)
args = parser.parse_args()
print(f'[i] I have job_id {args.job_id}')
print(f'[i] I have task_id {args.task_id}')
print(f'[i] Running on version [git:{self.get_version()}]')
if args.worker:
raise Exception('Not yet implemented')
assert args.config_file != None, 'Need to supply config file.'
if args.slurm:
self.run_slurm(args.config_file, args.experiment)
else:
self.run_local(args.config_file, args.experiment, args.job_id, args.sweep_id)
self.run_local(args.config_file, args.experiment, args.task_id, args.sweep_id)
def params_combine(config: dict, key: str, iter_func):