Wrote tests
This commit is contained in:
parent
eed4363ddd
commit
1321e47b81
@ -6,4 +6,10 @@ from metastable_baselines2 import PPO
|
||||
@pytest.mark.parametrize("env_id", ["LunarLanderContinuous-v2", "MountainCarContinuous-v0"])
|
||||
def test_trpl(env_id):
|
||||
model = PPO("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1)
|
||||
model.learn(total_timesteps=500)
|
||||
model.learn(total_timesteps=500)
|
||||
|
||||
@pytest.mark.parametrize("env_id", ["LunarLanderContinuous-v2"])
|
||||
@pytest.mark.parametrize("par_strength", ['DIAG', 'FULL', 'CONT_DIAG', 'CONT_FULL'])
|
||||
def test_ppo_pca(env_id, par_strength):
|
||||
model = PPO("MlpPolicy", env_id, n_steps=128, seed=0, use_pca=True, policy_kwargs=dict(net_arch=[16], dist_kwargs={'par_strength': par_strength, 'skip_conditioning': True}), verbose=1)
|
||||
model.learn(total_timesteps=100)
|
||||
|
@ -3,7 +3,7 @@ import pytest
|
||||
|
||||
from metastable_baselines2 import TRPL
|
||||
|
||||
PROJECTIONS = ["Frobenius", "Wasserstein"] #, "KL"]
|
||||
PROJECTIONS = ["Frobenius", "Wasserstein"] # KL
|
||||
|
||||
@pytest.mark.parametrize("env_id", ["LunarLanderContinuous-v2", "MountainCarContinuous-v0"])
|
||||
@pytest.mark.parametrize("projection", PROJECTIONS)
|
||||
@ -18,3 +18,10 @@ def test_trpl(env_id, projection):
|
||||
def test_trpl_params(env_id, projection, mean_bound, cov_bound):
|
||||
model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), projection_class=projection, projection_kwargs={'mean_bound': mean_bound, 'cov_bound': cov_bound}, verbose=1)
|
||||
model.learn(total_timesteps=100)
|
||||
|
||||
@pytest.mark.parametrize("env_id", ["LunarLanderContinuous-v2"])
|
||||
@pytest.mark.parametrize("projection", PROJECTIONS)
|
||||
@pytest.mark.parametrize("par_strength", ['DIAG', 'FULL', 'CONT_DIAG', 'CONT_FULL'])
|
||||
def test_trpl_pca(env_id, projection, par_strength):
|
||||
model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, use_pca=True, policy_kwargs=dict(net_arch=[16], dist_kwargs={'par_strength': par_strength, 'skip_conditioning': True}), projection_class=projection, verbose=1)
|
||||
model.learn(total_timesteps=100)
|
||||
|
Loading…
Reference in New Issue
Block a user