From 3b138bba1e74f935b0387dc186c2ee196e16ea88 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Fri, 16 Aug 2024 12:39:57 +0200 Subject: [PATCH] Bug fixes + tests --- example.py | 2 +- example.yaml | 4 ++-- slate/slate.py | 15 ++------------- 3 files changed, 5 insertions(+), 16 deletions(-) diff --git a/example.py b/example.py index 808863e..8dfe7e6 100644 --- a/example.py +++ b/example.py @@ -1,6 +1,6 @@ from slate import Slate -import fancy_gym +#import fancy_gym from stable_baselines3 import PPO from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder diff --git a/example.yaml b/example.yaml index b6f35ff..d1bcb70 100644 --- a/example.yaml +++ b/example.yaml @@ -118,7 +118,7 @@ name: Test_Sweep project: JustTesting runner: printConfig scheduler: - repetitions: 16 + reps_per_version: 1 wandb: group: 'sweep' job_type: '{delta_desc}' @@ -135,7 +135,7 @@ name: Test_CTR project: JustTesting runner: printConfig scheduler: - repetitions: 16 + reps_per_version: 1 reps_per_agent: 2 agents_per_job: 2 wandb: diff --git a/slate/slate.py b/slate/slate.py index 742397c..8be814f 100644 --- a/slate/slate.py +++ b/slate/slate.py @@ -341,8 +341,7 @@ class Slate(): for r in rep_ids: self.run_id = r - config = copy.deepcopy(orig_config) - runnerConf = self._make_config_for_run(config, r) + runnerConf = copy.deepcopy(orig_config) wandbC = self.consume(runnerConf, 'wandb', {}, expand=True, delta_desc=runnerConf.pop('delta_desc', 'BASE')) if 'job_type' in wandbC and len(wandbC['job_type']) > 62: wandbC['job_type'] = "..."+wandbC['job_type'][-50:] @@ -382,16 +381,6 @@ class Slate(): def _get_num_conv_versions(self, config): return len(self._make_configs_for_runs(config)) - def _make_config_for_run(self, config, r): - all_versions = self._make_configs_for_runs(config) - - i = r % len(all_versions) - print(f'[d] Running version {i}/{len(all_versions)} in run {r}') - cur_conf = all_versions[i] - if 'ablative' in cur_conf: - del cur_conf['ablative'] - return cur_conf - def _make_grid_versions(self, config): if 'grid' in config: return params_combine(config, 'grid', itertools.product) @@ -539,7 +528,7 @@ class Slate_Runner(): self.slate = slate self.config = config - def setup(self): + def setup(self, name): pass def run(self, run):