Allow pre validating runners before submitting to slurm

This commit is contained in:
Dominik Moritz Roth 2023-07-12 13:06:14 +02:00
parent 97137964fb
commit 6febf19800

View File

@ -10,6 +10,7 @@ from multiprocessing import Process
from threading import Thread from threading import Thread
import git import git
import datetime import datetime
from pprint import pprint
import pdb import pdb
d = pdb.set_trace d = pdb.set_trace
@ -38,8 +39,8 @@ else:
class Slate(): class Slate():
def __init__(self, runners): def __init__(self, runners):
self.runners = runners self.runners = runners
self.runners['printConfig'] = print_config_runner self.runners['printConfig'] = Print_Config_Runner
self.runners['pdb'] = pdb_runner self.runners['pdb'] = PDB_Runner
self._version = False self._version = False
def load_config(self, filename, name): def load_config(self, filename, name):
@ -182,6 +183,13 @@ class Slate():
schedC = self.consume(config, 'scheduler') schedC = self.consume(config, 'scheduler')
s_name = self.consume(slurmC, 'name') s_name = self.consume(slurmC, 'name')
runnerName, _ = self.consume(config, 'runner'), self.consume(config, 'wandb', {}, expand=True)
if self.consume(slurmC, 'pre_validate', True):
Runner = self.runners[runnerName]
runner = Runner(self, config)
runner.setup()
python_script = 'main.py' python_script = 'main.py'
sh_lines = ['#!/bin/bash'] sh_lines = ['#!/bin/bash']
sh_lines += self.consume(slurmC, 'sh_lines', []) sh_lines += self.consume(slurmC, 'sh_lines', [])
@ -251,7 +259,7 @@ class Slate():
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True) runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
project = self.consume(wandbC, 'project') project = self.consume(wandbC, 'project')
runner = self.runners[runnerName] Runner = self.runners[runnerName]
for r in rep_ids: for r in rep_ids:
config = copy.deepcopy(orig_config) config = copy.deepcopy(orig_config)
@ -262,7 +270,9 @@ class Slate():
settings=wandb.Settings(start_method=WANDB_START_METHOD), settings=wandb.Settings(start_method=WANDB_START_METHOD),
**wandbC **wandbC
) as run: ) as run:
runner(self, run, config) runner = Runner(self, config)
runner.setup()
runner.run(run)
if config != {}: if config != {}:
msg = ('Config was not completely consumed: ', config) msg = ('Config was not completely consumed: ', config)
@ -276,7 +286,7 @@ class Slate():
runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True) runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
project = self.consume(wandbC, 'project') project = self.consume(wandbC, 'project')
runner = self.runners[runnerName] Runner = self.runners[runnerName]
with wandb.init( with wandb.init(
project=project, project=project,
@ -286,7 +296,9 @@ class Slate():
) as run: ) as run:
config = copy.deepcopy(orig_config) config = copy.deepcopy(orig_config)
self.deep_update(config, wandb.config) self.deep_update(config, wandb.config)
runner(self, run, config) runner = Runner(self, config)
runner.setup()
runner.run(run)
if config != {}: if config != {}:
msg = ('Config was not completely consumed: ', config) msg = ('Config was not completely consumed: ', config)
@ -321,8 +333,22 @@ class Slate():
self.run_local(args.config_file, args.experiment, args.job_id) self.run_local(args.config_file, args.experiment, args.job_id)
def print_config_runner(slate, run, config): class Slate_Runner():
from pprint import pprint def __init__(self, slate, config):
self.slate = slate
self.config = config
def setup(self):
pass
def run(self, run):
pass
class Print_Config_Runner(Slate_Runner):
def run(self, run):
slate, config = self.slate, self.config
ptr = {'ptr': config} ptr = {'ptr': config}
pprint(config) pprint(config)
print('---') print('---')
@ -331,7 +357,8 @@ def print_config_runner(slate, run, config):
del config[k] del config[k]
def pdb_runner(slate, run, config): class PDB_Runner(Slate_Runner):
def run(self, run):
d() d()