Seperate job_id and task_id and allow access from configs
This commit is contained in:
parent
a16f3889f6
commit
6f9fbe6b28
@ -44,6 +44,9 @@ class Slate():
|
||||
}
|
||||
self.runners.update(runners)
|
||||
self._version = False
|
||||
self.job_id = os.environ.get('SLURM_JOB_ID', False)
|
||||
self.task_id = None
|
||||
self._tmp_path = os.path.expandvars('$TMP')
|
||||
self.sweep_id = None
|
||||
|
||||
def load_config(self, filename, name):
|
||||
@ -104,10 +107,9 @@ class Slate():
|
||||
def expand_vars(self, string, delta_desc='BASE', **kwargs):
|
||||
if isinstance(string, str):
|
||||
rand = int(random.random()*99999999)
|
||||
tmp = os.path.expandvars('$TMP')
|
||||
if string == '{rand}':
|
||||
return rand
|
||||
return string.format(delta_desc=delta_desc, **kwargs, rand=rand, tmp=tmp)
|
||||
return string.format(delta_desc=delta_desc, **kwargs, rand=rand, tmp=self._tmp_path, job_id=(self.job_id or 'LOCAL'), task_id=(self.task_id or 0))
|
||||
return string
|
||||
|
||||
def apply_nested(self, d, f):
|
||||
@ -175,22 +177,23 @@ class Slate():
|
||||
jobs_needed = math.ceil(reps / reps_per_job)
|
||||
return jobs_needed
|
||||
|
||||
def _reps_for_job(self, schedC, job_id, num_conv_versions):
|
||||
def _reps_for_job(self, schedC, task_id, num_conv_versions):
|
||||
schedulerC = copy.deepcopy(schedC)
|
||||
num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions)
|
||||
reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions)
|
||||
if job_id == None:
|
||||
if task_id == None:
|
||||
return list(range(0, reps))
|
||||
reps_for_job = [[] for i in range(num_jobs)]
|
||||
for i in range(reps):
|
||||
reps_for_job[i % num_jobs].append(i)
|
||||
return reps_for_job[job_id]
|
||||
return reps_for_job[task_id]
|
||||
|
||||
def run_local(self, filename, name, job_id, sweep_id):
|
||||
def run_local(self, filename, name, task_id, sweep_id):
|
||||
self.task_id = task_id
|
||||
config = self.load_config(filename, name)
|
||||
num_conv_versions = self._get_num_conv_versions(config)
|
||||
schedulerC = copy.deepcopy(config.get('scheduler', {}))
|
||||
rep_ids = self._reps_for_job(schedulerC, job_id, num_conv_versions)
|
||||
rep_ids = self._reps_for_job(schedulerC, task_id, num_conv_versions)
|
||||
self.sweep_id = sweep_id
|
||||
self._init_sweep(config)
|
||||
self._fork_processes(config, rep_ids)
|
||||
@ -219,7 +222,7 @@ class Slate():
|
||||
sh_lines += self.consume(slurmC, 'sh_lines', [])
|
||||
if venv := self.consume(slurmC, 'venv', False):
|
||||
sh_lines += [f'source activate {venv}']
|
||||
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID --sweep_id {self.sweep_id}']
|
||||
sh_lines += [f'python3 {python_script} {filename} {name} -t $SLURM_ARRAY_TASK_ID --sweet_id {self.sweep_id}']
|
||||
script = "\n".join(sh_lines)
|
||||
|
||||
num_jobs = self._calc_num_jobs(schedC, num_conv_versions)
|
||||
@ -380,22 +383,23 @@ class Slate():
|
||||
parser.add_argument("experiment", nargs='?', default='DEFAULT')
|
||||
parser.add_argument("-s", "--slurm", action="store_true")
|
||||
parser.add_argument("-w", "--worker", action="store_true")
|
||||
parser.add_argument("-j", "--job_id", default=None, type=int)
|
||||
parser.add_argument("-t", "--task_id", default=None, type=int)
|
||||
parser.add_argument("--sweep_id", default=None, type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f'[i] I have job_id {args.job_id}')
|
||||
print(f'[i] I have task_id {args.task_id}')
|
||||
print(f'[i] Running on version [git:{self.get_version()}]')
|
||||
|
||||
if args.worker:
|
||||
raise Exception('Not yet implemented')
|
||||
|
||||
assert args.config_file != None, 'Need to supply config file.'
|
||||
|
||||
if args.slurm:
|
||||
self.run_slurm(args.config_file, args.experiment)
|
||||
else:
|
||||
self.run_local(args.config_file, args.experiment, args.job_id, args.sweep_id)
|
||||
self.run_local(args.config_file, args.experiment, args.task_id, args.sweep_id)
|
||||
|
||||
|
||||
def params_combine(config: dict, key: str, iter_func):
|
||||
|
Loading…
Reference in New Issue
Block a user