Allow pre validating runners before submitting to slurm
This commit is contained in:
parent
97137964fb
commit
6febf19800
@ -10,6 +10,7 @@ from multiprocessing import Process
|
|||||||
from threading import Thread
|
from threading import Thread
|
||||||
import git
|
import git
|
||||||
import datetime
|
import datetime
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
import pdb
|
import pdb
|
||||||
d = pdb.set_trace
|
d = pdb.set_trace
|
||||||
@ -38,8 +39,8 @@ else:
|
|||||||
class Slate():
|
class Slate():
|
||||||
def __init__(self, runners):
|
def __init__(self, runners):
|
||||||
self.runners = runners
|
self.runners = runners
|
||||||
self.runners['printConfig'] = print_config_runner
|
self.runners['printConfig'] = Print_Config_Runner
|
||||||
self.runners['pdb'] = pdb_runner
|
self.runners['pdb'] = PDB_Runner
|
||||||
self._version = False
|
self._version = False
|
||||||
|
|
||||||
def load_config(self, filename, name):
|
def load_config(self, filename, name):
|
||||||
@ -182,6 +183,13 @@ class Slate():
|
|||||||
schedC = self.consume(config, 'scheduler')
|
schedC = self.consume(config, 'scheduler')
|
||||||
s_name = self.consume(slurmC, 'name')
|
s_name = self.consume(slurmC, 'name')
|
||||||
|
|
||||||
|
runnerName, _ = self.consume(config, 'runner'), self.consume(config, 'wandb', {}, expand=True)
|
||||||
|
|
||||||
|
if self.consume(slurmC, 'pre_validate', True):
|
||||||
|
Runner = self.runners[runnerName]
|
||||||
|
runner = Runner(self, config)
|
||||||
|
runner.setup()
|
||||||
|
|
||||||
python_script = 'main.py'
|
python_script = 'main.py'
|
||||||
sh_lines = ['#!/bin/bash']
|
sh_lines = ['#!/bin/bash']
|
||||||
sh_lines += self.consume(slurmC, 'sh_lines', [])
|
sh_lines += self.consume(slurmC, 'sh_lines', [])
|
||||||
@ -251,7 +259,7 @@ class Slate():
|
|||||||
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
|
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
|
||||||
project = self.consume(wandbC, 'project')
|
project = self.consume(wandbC, 'project')
|
||||||
|
|
||||||
runner = self.runners[runnerName]
|
Runner = self.runners[runnerName]
|
||||||
|
|
||||||
for r in rep_ids:
|
for r in rep_ids:
|
||||||
config = copy.deepcopy(orig_config)
|
config = copy.deepcopy(orig_config)
|
||||||
@ -262,7 +270,9 @@ class Slate():
|
|||||||
settings=wandb.Settings(start_method=WANDB_START_METHOD),
|
settings=wandb.Settings(start_method=WANDB_START_METHOD),
|
||||||
**wandbC
|
**wandbC
|
||||||
) as run:
|
) as run:
|
||||||
runner(self, run, config)
|
runner = Runner(self, config)
|
||||||
|
runner.setup()
|
||||||
|
runner.run(run)
|
||||||
|
|
||||||
if config != {}:
|
if config != {}:
|
||||||
msg = ('Config was not completely consumed: ', config)
|
msg = ('Config was not completely consumed: ', config)
|
||||||
@ -276,7 +286,7 @@ class Slate():
|
|||||||
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
|
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
|
||||||
project = self.consume(wandbC, 'project')
|
project = self.consume(wandbC, 'project')
|
||||||
|
|
||||||
runner = self.runners[runnerName]
|
Runner = self.runners[runnerName]
|
||||||
|
|
||||||
with wandb.init(
|
with wandb.init(
|
||||||
project=project,
|
project=project,
|
||||||
@ -286,7 +296,9 @@ class Slate():
|
|||||||
) as run:
|
) as run:
|
||||||
config = copy.deepcopy(orig_config)
|
config = copy.deepcopy(orig_config)
|
||||||
self.deep_update(config, wandb.config)
|
self.deep_update(config, wandb.config)
|
||||||
runner(self, run, config)
|
runner = Runner(self, config)
|
||||||
|
runner.setup()
|
||||||
|
runner.run(run)
|
||||||
|
|
||||||
if config != {}:
|
if config != {}:
|
||||||
msg = ('Config was not completely consumed: ', config)
|
msg = ('Config was not completely consumed: ', config)
|
||||||
@ -321,18 +333,33 @@ class Slate():
|
|||||||
self.run_local(args.config_file, args.experiment, args.job_id)
|
self.run_local(args.config_file, args.experiment, args.job_id)
|
||||||
|
|
||||||
|
|
||||||
def print_config_runner(slate, run, config):
|
class Slate_Runner():
|
||||||
from pprint import pprint
|
def __init__(self, slate, config):
|
||||||
ptr = {'ptr': config}
|
self.slate = slate
|
||||||
pprint(config)
|
self.config = config
|
||||||
print('---')
|
|
||||||
pprint(slate.consume(ptr, 'ptr', expand=True))
|
def setup(self):
|
||||||
for k in list(config.keys()):
|
pass
|
||||||
del config[k]
|
|
||||||
|
def run(self, run):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def pdb_runner(slate, run, config):
|
class Print_Config_Runner(Slate_Runner):
|
||||||
d()
|
def run(self, run):
|
||||||
|
slate, config = self.slate, self.config
|
||||||
|
|
||||||
|
ptr = {'ptr': config}
|
||||||
|
pprint(config)
|
||||||
|
print('---')
|
||||||
|
pprint(slate.consume(ptr, 'ptr', expand=True))
|
||||||
|
for k in list(config.keys()):
|
||||||
|
del config[k]
|
||||||
|
|
||||||
|
|
||||||
|
class PDB_Runner(Slate_Runner):
|
||||||
|
def run(self, run):
|
||||||
|
d()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Loading…
Reference in New Issue
Block a user