Bug fixes + tests
This commit is contained in:
parent
f2b5837b56
commit
3b138bba1e
@ -1,6 +1,6 @@
|
|||||||
from slate import Slate
|
from slate import Slate
|
||||||
|
|
||||||
import fancy_gym
|
#import fancy_gym
|
||||||
from stable_baselines3 import PPO
|
from stable_baselines3 import PPO
|
||||||
from stable_baselines3.common.monitor import Monitor
|
from stable_baselines3.common.monitor import Monitor
|
||||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
|
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
|
||||||
|
@ -118,7 +118,7 @@ name: Test_Sweep
|
|||||||
project: JustTesting
|
project: JustTesting
|
||||||
runner: printConfig
|
runner: printConfig
|
||||||
scheduler:
|
scheduler:
|
||||||
repetitions: 16
|
reps_per_version: 1
|
||||||
wandb:
|
wandb:
|
||||||
group: 'sweep'
|
group: 'sweep'
|
||||||
job_type: '{delta_desc}'
|
job_type: '{delta_desc}'
|
||||||
@ -135,7 +135,7 @@ name: Test_CTR
|
|||||||
project: JustTesting
|
project: JustTesting
|
||||||
runner: printConfig
|
runner: printConfig
|
||||||
scheduler:
|
scheduler:
|
||||||
repetitions: 16
|
reps_per_version: 1
|
||||||
reps_per_agent: 2
|
reps_per_agent: 2
|
||||||
agents_per_job: 2
|
agents_per_job: 2
|
||||||
wandb:
|
wandb:
|
||||||
|
@ -341,8 +341,7 @@ class Slate():
|
|||||||
|
|
||||||
for r in rep_ids:
|
for r in rep_ids:
|
||||||
self.run_id = r
|
self.run_id = r
|
||||||
config = copy.deepcopy(orig_config)
|
runnerConf = copy.deepcopy(orig_config)
|
||||||
runnerConf = self._make_config_for_run(config, r)
|
|
||||||
wandbC = self.consume(runnerConf, 'wandb', {}, expand=True, delta_desc=runnerConf.pop('delta_desc', 'BASE'))
|
wandbC = self.consume(runnerConf, 'wandb', {}, expand=True, delta_desc=runnerConf.pop('delta_desc', 'BASE'))
|
||||||
if 'job_type' in wandbC and len(wandbC['job_type']) > 62:
|
if 'job_type' in wandbC and len(wandbC['job_type']) > 62:
|
||||||
wandbC['job_type'] = "..."+wandbC['job_type'][-50:]
|
wandbC['job_type'] = "..."+wandbC['job_type'][-50:]
|
||||||
@ -382,16 +381,6 @@ class Slate():
|
|||||||
def _get_num_conv_versions(self, config):
|
def _get_num_conv_versions(self, config):
|
||||||
return len(self._make_configs_for_runs(config))
|
return len(self._make_configs_for_runs(config))
|
||||||
|
|
||||||
def _make_config_for_run(self, config, r):
|
|
||||||
all_versions = self._make_configs_for_runs(config)
|
|
||||||
|
|
||||||
i = r % len(all_versions)
|
|
||||||
print(f'[d] Running version {i}/{len(all_versions)} in run {r}')
|
|
||||||
cur_conf = all_versions[i]
|
|
||||||
if 'ablative' in cur_conf:
|
|
||||||
del cur_conf['ablative']
|
|
||||||
return cur_conf
|
|
||||||
|
|
||||||
def _make_grid_versions(self, config):
|
def _make_grid_versions(self, config):
|
||||||
if 'grid' in config:
|
if 'grid' in config:
|
||||||
return params_combine(config, 'grid', itertools.product)
|
return params_combine(config, 'grid', itertools.product)
|
||||||
@ -539,7 +528,7 @@ class Slate_Runner():
|
|||||||
self.slate = slate
|
self.slate = slate
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
def setup(self):
|
def setup(self, name):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def run(self, run):
|
def run(self, run):
|
||||||
|
Loading…
Reference in New Issue
Block a user