From 6febf1980094183e7a53662e1db5c8235f788838 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 12 Jul 2023 13:06:14 +0200 Subject: [PATCH] Allow pre validating runners before submitting to slurm --- slate/slate.py | 59 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 43 insertions(+), 16 deletions(-) diff --git a/slate/slate.py b/slate/slate.py index d1a4ee3..a73c0d8 100644 --- a/slate/slate.py +++ b/slate/slate.py @@ -10,6 +10,7 @@ from multiprocessing import Process from threading import Thread import git import datetime +from pprint import pprint import pdb d = pdb.set_trace @@ -38,8 +39,8 @@ else: class Slate(): def __init__(self, runners): self.runners = runners - self.runners['printConfig'] = print_config_runner - self.runners['pdb'] = pdb_runner + self.runners['printConfig'] = Print_Config_Runner + self.runners['pdb'] = PDB_Runner self._version = False def load_config(self, filename, name): @@ -182,6 +183,13 @@ class Slate(): schedC = self.consume(config, 'scheduler') s_name = self.consume(slurmC, 'name') + runnerName, _ = self.consume(config, 'runner'), self.consume(config, 'wandb', {}, expand=True) + + if self.consume(slurmC, 'pre_validate', True): + Runner = self.runners[runnerName] + runner = Runner(self, config) + runner.setup() + python_script = 'main.py' sh_lines = ['#!/bin/bash'] sh_lines += self.consume(slurmC, 'sh_lines', []) @@ -251,7 +259,7 @@ class Slate(): runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True) project = self.consume(wandbC, 'project') - runner = self.runners[runnerName] + Runner = self.runners[runnerName] for r in rep_ids: config = copy.deepcopy(orig_config) @@ -262,7 +270,9 @@ class Slate(): settings=wandb.Settings(start_method=WANDB_START_METHOD), **wandbC ) as run: - runner(self, run, config) + runner = Runner(self, config) + runner.setup() + runner.run(run) if config != {}: msg = ('Config was not completely consumed: ', config) @@ -276,7 +286,7 @@ class Slate(): runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True) project = self.consume(wandbC, 'project') - runner = self.runners[runnerName] + Runner = self.runners[runnerName] with wandb.init( project=project, @@ -286,7 +296,9 @@ class Slate(): ) as run: config = copy.deepcopy(orig_config) self.deep_update(config, wandb.config) - runner(self, run, config) + runner = Runner(self, config) + runner.setup() + runner.run(run) if config != {}: msg = ('Config was not completely consumed: ', config) @@ -321,18 +333,33 @@ class Slate(): self.run_local(args.config_file, args.experiment, args.job_id) -def print_config_runner(slate, run, config): - from pprint import pprint - ptr = {'ptr': config} - pprint(config) - print('---') - pprint(slate.consume(ptr, 'ptr', expand=True)) - for k in list(config.keys()): - del config[k] +class Slate_Runner(): + def __init__(self, slate, config): + self.slate = slate + self.config = config + + def setup(self): + pass + + def run(self, run): + pass -def pdb_runner(slate, run, config): - d() +class Print_Config_Runner(Slate_Runner): + def run(self, run): + slate, config = self.slate, self.config + + ptr = {'ptr': config} + pprint(config) + print('---') + pprint(slate.consume(ptr, 'ptr', expand=True)) + for k in list(config.keys()): + del config[k] + + +class PDB_Runner(Slate_Runner): + def run(self, run): + d() if __name__ == '__main__':