Fix expand vars in array

This commit is contained in:
Dominik Moritz Roth 2023-07-12 11:07:33 +02:00
parent ee2efa6da9
commit e1af07a68d

View File

@ -8,6 +8,8 @@ import collections.abc
from functools import partial from functools import partial
from multiprocessing import Process from multiprocessing import Process
from threading import Thread from threading import Thread
import git
import datetime
import pdb import pdb
d = pdb.set_trace d = pdb.set_trace
@ -102,7 +104,9 @@ class Slate():
self.apply_nested(v, f) self.apply_nested(v, f)
elif isinstance(v, list): elif isinstance(v, list):
for i, e in enumerate(v): for i, e in enumerate(v):
self.apply_nested({'PTR': d[k][i]}, f) ptr = {'PTR': d[k][i]}
self.apply_nested(ptr, f)
d[k][i] = ptr['PTR']
else: else:
d[k] = f(v) d[k] = f(v)
@ -129,7 +133,7 @@ class Slate():
return val return val
child = conf.get(keys_arr[0], {}) child = conf.get(keys_arr[0], {})
child_keys = '.'.join(keys_arr[1:]) child_keys = '.'.join(keys_arr[1:])
return self.consume(child, child_keys, default=default, **kwargs) return self.consume(child, child_keys, default=default, expand=expand, **kwargs)
def _calc_num_jobs(self, schedC): def _calc_num_jobs(self, schedC):
schedulerC = copy.deepcopy(schedC) schedulerC = copy.deepcopy(schedC)
@ -151,13 +155,13 @@ class Slate():
reps_for_job[i % num_jobs].append(i) reps_for_job[i % num_jobs].append(i)
return reps_for_job[job_id-1] return reps_for_job[job_id-1]
def run_local(self, filename, name, job_id): def run_local(self, filename, name, sha, job_id):
config = self.load_config(filename, name) config = self.load_config(filename, name)
schedulerC = copy.deepcopy(config.get('scheduler', {})) schedulerC = copy.deepcopy(config.get('scheduler', {}))
rep_ids = self._reps_for_job(schedulerC, job_id) rep_ids = self._reps_for_job(schedulerC, job_id)
self._fork_processes(config, rep_ids) self._fork_processes(config, rep_ids)
def run_slurm(self, filename, name): def run_slurm(self, filename, name, sha):
assert slurm_avaible, 'pyslurm does not seem to be installed on this system.' assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
config = self.load_config(filename, name) config = self.load_config(filename, name)
slurmC = self.consume(config, 'slurm', expand=True) slurmC = self.consume(config, 'slurm', expand=True)
@ -180,6 +184,8 @@ class Slate():
job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC) job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
job_id = job.submit() job_id = job.submit()
print(f'[i] Job submitted to slurm with id {job_id}') print(f'[i] Job submitted to slurm with id {job_id}')
with open('job_hist.log', 'w') as f:
f.write(f'{filename}:{name} submitted to slurm with ids {job_id}_0-{num_jobs} on [git:{sha}] at {datetime.datetime.now()}')
def _fork_processes(self, config, rep_ids): def _fork_processes(self, config, rep_ids):
schedC = self.consume(config, 'scheduler') schedC = self.consume(config, 'scheduler')
@ -288,16 +294,20 @@ class Slate():
args = parser.parse_args() args = parser.parse_args()
print('I have job_id ', args.job_id) repo = git.Repo(search_parent_directories=True)
sha = repo.head.object.hexsha
print(f'[i] I have job_id {args.job_id}')
print(f'[i] Running on version [git:{sha}]')
if args.worker: if args.worker:
raise Exception('Not yet implemented') raise Exception('Not yet implemented')
assert args.config_file != None, 'Need to supply config file.' assert args.config_file != None, 'Need to supply config file.'
if args.slurm: if args.slurm:
self.run_slurm(args.config_file, args.experiment) self.run_slurm(args.config_file, args.experiment, sha)
else: else:
self.run_local(args.config_file, args.experiment, args.job_id) self.run_local(args.config_file, args.experiment, sha, args.job_id)
def print_config_runner(slate, run, config): def print_config_runner(slate, run, config):