Bug fixes + tests

This commit is contained in:
Dominik Moritz Roth 2024-08-16 12:39:57 +02:00
parent f2b5837b56
commit 3b138bba1e
3 changed files with 5 additions and 16 deletions

View File

@ -1,6 +1,6 @@
from slate import Slate
import fancy_gym
#import fancy_gym
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder

View File

@ -118,7 +118,7 @@ name: Test_Sweep
project: JustTesting
runner: printConfig
scheduler:
repetitions: 16
reps_per_version: 1
wandb:
group: 'sweep'
job_type: '{delta_desc}'
@ -135,7 +135,7 @@ name: Test_CTR
project: JustTesting
runner: printConfig
scheduler:
repetitions: 16
reps_per_version: 1
reps_per_agent: 2
agents_per_job: 2
wandb:

View File

@ -341,8 +341,7 @@ class Slate():
for r in rep_ids:
self.run_id = r
config = copy.deepcopy(orig_config)
runnerConf = self._make_config_for_run(config, r)
runnerConf = copy.deepcopy(orig_config)
wandbC = self.consume(runnerConf, 'wandb', {}, expand=True, delta_desc=runnerConf.pop('delta_desc', 'BASE'))
if 'job_type' in wandbC and len(wandbC['job_type']) > 62:
wandbC['job_type'] = "..."+wandbC['job_type'][-50:]
@ -382,16 +381,6 @@ class Slate():
def _get_num_conv_versions(self, config):
return len(self._make_configs_for_runs(config))
def _make_config_for_run(self, config, r):
all_versions = self._make_configs_for_runs(config)
i = r % len(all_versions)
print(f'[d] Running version {i}/{len(all_versions)} in run {r}')
cur_conf = all_versions[i]
if 'ablative' in cur_conf:
del cur_conf['ablative']
return cur_conf
def _make_grid_versions(self, config):
if 'grid' in config:
return params_combine(config, 'grid', itertools.product)
@ -539,7 +528,7 @@ class Slate_Runner():
self.slate = slate
self.config = config
def setup(self):
def setup(self, name):
pass
def run(self, run):