Allow pre-validating runners before submitting to Slurm

This commit is contained in:
Dominik Moritz Roth 2023-07-12 13:06:14 +02:00
parent 97137964fb
commit 6febf19800

View File

@ -10,6 +10,7 @@ from multiprocessing import Process
from threading import Thread
import git
import datetime
from pprint import pprint
import pdb
d = pdb.set_trace
@ -38,8 +39,8 @@ else:
class Slate():
def __init__(self, runners):
self.runners = runners
self.runners['printConfig'] = print_config_runner
self.runners['pdb'] = pdb_runner
self.runners['printConfig'] = Print_Config_Runner
self.runners['pdb'] = PDB_Runner
self._version = False
def load_config(self, filename, name):
@ -182,6 +183,13 @@ class Slate():
schedC = self.consume(config, 'scheduler')
s_name = self.consume(slurmC, 'name')
runnerName, _ = self.consume(config, 'runner'), self.consume(config, 'wandb', {}, expand=True)
if self.consume(slurmC, 'pre_validate', True):
Runner = self.runners[runnerName]
runner = Runner(self, config)
runner.setup()
python_script = 'main.py'
sh_lines = ['#!/bin/bash']
sh_lines += self.consume(slurmC, 'sh_lines', [])
@ -251,7 +259,7 @@ class Slate():
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
project = self.consume(wandbC, 'project')
runner = self.runners[runnerName]
Runner = self.runners[runnerName]
for r in rep_ids:
config = copy.deepcopy(orig_config)
@ -262,7 +270,9 @@ class Slate():
settings=wandb.Settings(start_method=WANDB_START_METHOD),
**wandbC
) as run:
runner(self, run, config)
runner = Runner(self, config)
runner.setup()
runner.run(run)
if config != {}:
msg = ('Config was not completely consumed: ', config)
@ -276,7 +286,7 @@ class Slate():
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
project = self.consume(wandbC, 'project')
runner = self.runners[runnerName]
Runner = self.runners[runnerName]
with wandb.init(
project=project,
@ -286,7 +296,9 @@ class Slate():
) as run:
config = copy.deepcopy(orig_config)
self.deep_update(config, wandb.config)
runner(self, run, config)
runner = Runner(self, config)
runner.setup()
runner.run(run)
if config != {}:
msg = ('Config was not completely consumed: ', config)
@ -321,18 +333,33 @@ class Slate():
self.run_local(args.config_file, args.experiment, args.job_id)
def print_config_runner(slate, run, config):
    """Pretty-print ``config`` raw and expanded, then empty it in place.

    The raw mapping is printed first, followed by a ``---`` separator and
    the result of ``slate.consume(..., expand=True)`` applied to the whole
    config. Afterwards every key is deleted from ``config`` so the caller
    is left with an empty mapping.

    Args:
        slate: Object providing ``consume(mapping, key, expand=...)``.
        run: Accepted for the runner call signature; not used here.
        config: Mutable mapping to display; it is emptied before returning.
    """
    from pprint import pprint

    pprint(config)
    print('---')
    # Wrap the config under a throwaway key so the mapping as a whole can
    # be handed to consume() for expansion.
    wrapper = {'ptr': config}
    expanded = slate.consume(wrapper, 'ptr', expand=True)
    pprint(expanded)
    # Snapshot the keys first: deleting while iterating a live view fails.
    for key in tuple(config.keys()):
        del config[key]
class Slate_Runner:
    """Base runner: keeps a slate and a config, and offers no-op hooks.

    ``setup`` and ``run`` do nothing here; subclasses are expected to
    override whichever of them they need.
    """

    def __init__(self, slate, config):
        """Store ``slate`` and ``config`` on the instance unchanged."""
        self.slate, self.config = slate, config

    def setup(self):
        """Preparation hook; does nothing in the base class."""
        return None

    def run(self, run):
        """Execution hook receiving ``run``; does nothing in the base class."""
        return None
def pdb_runner(slate, run, config):
    """Start an interactive debugging session.

    Calls the module-level ``d`` alias (bound to ``pdb.set_trace`` near the
    top of the file). ``slate``, ``run`` and ``config`` are not used by the
    code itself.
    """
    d()
class Print_Config_Runner(Slate_Runner):
    """Runner that pretty-prints its config, raw and expanded, then empties it."""

    def run(self, run):
        """Print the stored config twice and clear it in place.

        The raw config is printed first, then a ``---`` separator, then the
        result of ``self.slate.consume(..., expand=True)`` applied to the
        whole config. Finally every key is deleted from ``self.config``.
        ``run`` is not used.
        """
        owner = self.slate
        cfg = self.config
        pprint(cfg)
        print('---')
        # Wrap the config under a throwaway key so the mapping as a whole
        # can be handed to consume() for expansion.
        wrapper = {'ptr': cfg}
        pprint(owner.consume(wrapper, 'ptr', expand=True))
        # Snapshot the keys first: deleting while iterating a live view fails.
        for key in tuple(cfg.keys()):
            del cfg[key]
class PDB_Runner(Slate_Runner):
    """Runner whose ``run`` step starts an interactive debugging session."""

    def run(self, run):
        """Call the module-level ``d`` alias (bound to ``pdb.set_trace``).

        ``run`` is not used by the code itself.
        """
        d()
if __name__ == '__main__':