Allow pre validating runners before submitting to slurm
This commit is contained in:
parent
97137964fb
commit
6febf19800
@ -10,6 +10,7 @@ from multiprocessing import Process
|
||||
from threading import Thread
|
||||
import git
|
||||
import datetime
|
||||
from pprint import pprint
|
||||
|
||||
import pdb
|
||||
d = pdb.set_trace
|
||||
@ -38,8 +39,8 @@ else:
|
||||
class Slate():
|
||||
def __init__(self, runners):
|
||||
self.runners = runners
|
||||
self.runners['printConfig'] = print_config_runner
|
||||
self.runners['pdb'] = pdb_runner
|
||||
self.runners['printConfig'] = Print_Config_Runner
|
||||
self.runners['pdb'] = PDB_Runner
|
||||
self._version = False
|
||||
|
||||
def load_config(self, filename, name):
|
||||
@ -182,6 +183,13 @@ class Slate():
|
||||
schedC = self.consume(config, 'scheduler')
|
||||
s_name = self.consume(slurmC, 'name')
|
||||
|
||||
runnerName, _ = self.consume(config, 'runner'), self.consume(config, 'wandb', {}, expand=True)
|
||||
|
||||
if self.consume(slurmC, 'pre_validate', True):
|
||||
Runner = self.runners[runnerName]
|
||||
runner = Runner(self, config)
|
||||
runner.setup()
|
||||
|
||||
python_script = 'main.py'
|
||||
sh_lines = ['#!/bin/bash']
|
||||
sh_lines += self.consume(slurmC, 'sh_lines', [])
|
||||
@ -251,7 +259,7 @@ class Slate():
|
||||
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
|
||||
project = self.consume(wandbC, 'project')
|
||||
|
||||
runner = self.runners[runnerName]
|
||||
Runner = self.runners[runnerName]
|
||||
|
||||
for r in rep_ids:
|
||||
config = copy.deepcopy(orig_config)
|
||||
@ -262,7 +270,9 @@ class Slate():
|
||||
settings=wandb.Settings(start_method=WANDB_START_METHOD),
|
||||
**wandbC
|
||||
) as run:
|
||||
runner(self, run, config)
|
||||
runner = Runner(self, config)
|
||||
runner.setup()
|
||||
runner.run(run)
|
||||
|
||||
if config != {}:
|
||||
msg = ('Config was not completely consumed: ', config)
|
||||
@ -276,7 +286,7 @@ class Slate():
|
||||
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
|
||||
project = self.consume(wandbC, 'project')
|
||||
|
||||
runner = self.runners[runnerName]
|
||||
Runner = self.runners[runnerName]
|
||||
|
||||
with wandb.init(
|
||||
project=project,
|
||||
@ -286,7 +296,9 @@ class Slate():
|
||||
) as run:
|
||||
config = copy.deepcopy(orig_config)
|
||||
self.deep_update(config, wandb.config)
|
||||
runner(self, run, config)
|
||||
runner = Runner(self, config)
|
||||
runner.setup()
|
||||
runner.run(run)
|
||||
|
||||
if config != {}:
|
||||
msg = ('Config was not completely consumed: ', config)
|
||||
@ -321,18 +333,33 @@ class Slate():
|
||||
self.run_local(args.config_file, args.experiment, args.job_id)
|
||||
|
||||
|
||||
def print_config_runner(slate, run, config):
|
||||
from pprint import pprint
|
||||
ptr = {'ptr': config}
|
||||
pprint(config)
|
||||
print('---')
|
||||
pprint(slate.consume(ptr, 'ptr', expand=True))
|
||||
for k in list(config.keys()):
|
||||
del config[k]
|
||||
class Slate_Runner():
|
||||
def __init__(self, slate, config):
|
||||
self.slate = slate
|
||||
self.config = config
|
||||
|
||||
def setup(self):
|
||||
pass
|
||||
|
||||
def run(self, run):
|
||||
pass
|
||||
|
||||
|
||||
def pdb_runner(slate, run, config):
|
||||
d()
|
||||
class Print_Config_Runner(Slate_Runner):
|
||||
def run(self, run):
|
||||
slate, config = self.slate, self.config
|
||||
|
||||
ptr = {'ptr': config}
|
||||
pprint(config)
|
||||
print('---')
|
||||
pprint(slate.consume(ptr, 'ptr', expand=True))
|
||||
for k in list(config.keys()):
|
||||
del config[k]
|
||||
|
||||
|
||||
class PDB_Runner(Slate_Runner):
|
||||
def run(self, run):
|
||||
d()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user