Bug fixes + tests

This commit is contained in:
Dominik Moritz Roth 2024-08-16 12:39:57 +02:00
parent f2b5837b56
commit 3b138bba1e
3 changed files with 5 additions and 16 deletions

View File

@ -1,6 +1,6 @@
from slate import Slate from slate import Slate
import fancy_gym #import fancy_gym
from stable_baselines3 import PPO from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder

View File

@ -118,7 +118,7 @@ name: Test_Sweep
project: JustTesting project: JustTesting
runner: printConfig runner: printConfig
scheduler: scheduler:
repetitions: 16 reps_per_version: 1
wandb: wandb:
group: 'sweep' group: 'sweep'
job_type: '{delta_desc}' job_type: '{delta_desc}'
@ -135,7 +135,7 @@ name: Test_CTR
project: JustTesting project: JustTesting
runner: printConfig runner: printConfig
scheduler: scheduler:
repetitions: 16 reps_per_version: 1
reps_per_agent: 2 reps_per_agent: 2
agents_per_job: 2 agents_per_job: 2
wandb: wandb:

View File

@ -341,8 +341,7 @@ class Slate():
for r in rep_ids: for r in rep_ids:
self.run_id = r self.run_id = r
config = copy.deepcopy(orig_config) runnerConf = copy.deepcopy(orig_config)
runnerConf = self._make_config_for_run(config, r)
wandbC = self.consume(runnerConf, 'wandb', {}, expand=True, delta_desc=runnerConf.pop('delta_desc', 'BASE')) wandbC = self.consume(runnerConf, 'wandb', {}, expand=True, delta_desc=runnerConf.pop('delta_desc', 'BASE'))
if 'job_type' in wandbC and len(wandbC['job_type']) > 62: if 'job_type' in wandbC and len(wandbC['job_type']) > 62:
wandbC['job_type'] = "..."+wandbC['job_type'][-50:] wandbC['job_type'] = "..."+wandbC['job_type'][-50:]
@ -382,16 +381,6 @@ class Slate():
def _get_num_conv_versions(self, config): def _get_num_conv_versions(self, config):
return len(self._make_configs_for_runs(config)) return len(self._make_configs_for_runs(config))
def _make_config_for_run(self, config, r):
all_versions = self._make_configs_for_runs(config)
i = r % len(all_versions)
print(f'[d] Running version {i}/{len(all_versions)} in run {r}')
cur_conf = all_versions[i]
if 'ablative' in cur_conf:
del cur_conf['ablative']
return cur_conf
def _make_grid_versions(self, config): def _make_grid_versions(self, config):
if 'grid' in config: if 'grid' in config:
return params_combine(config, 'grid', itertools.product) return params_combine(config, 'grid', itertools.product)
@ -539,7 +528,7 @@ class Slate_Runner():
self.slate = slate self.slate = slate
self.config = config self.config = config
def setup(self): def setup(self, name):
pass pass
def run(self, run): def run(self, run):