diff --git a/tests/test_ppo.py b/tests/test_ppo.py
index 17b4b0d..5389370 100644
--- a/tests/test_ppo.py
+++ b/tests/test_ppo.py
@@ -6,4 +6,10 @@ from metastable_baselines2 import PPO
 @pytest.mark.parametrize("env_id", ["LunarLanderContinuous-v2", "MountainCarContinuous-v0"])
 def test_trpl(env_id):
     model = PPO("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1)
-    model.learn(total_timesteps=500)
\ No newline at end of file
+    model.learn(total_timesteps=500)
+
+@pytest.mark.parametrize("env_id", ["LunarLanderContinuous-v2"])
+@pytest.mark.parametrize("par_strength", ['DIAG', 'FULL', 'CONT_DIAG', 'CONT_FULL'])
+def test_ppo_pca(env_id, par_strength):
+    model = PPO("MlpPolicy", env_id, n_steps=128, seed=0, use_pca=True, policy_kwargs=dict(net_arch=[16], dist_kwargs={'par_strength': par_strength, 'skip_conditioning': True}), verbose=1)
+    model.learn(total_timesteps=100)
diff --git a/tests/test_trpl.py b/tests/test_trpl.py
index 07e715d..01f6d8d 100644
--- a/tests/test_trpl.py
+++ b/tests/test_trpl.py
@@ -3,7 +3,7 @@ import pytest
 
 from metastable_baselines2 import TRPL
 
-PROJECTIONS = ["Frobenius", "Wasserstein"] #, "KL"]
+PROJECTIONS = ["Frobenius", "Wasserstein"] # KL
 
 @pytest.mark.parametrize("env_id", ["LunarLanderContinuous-v2", "MountainCarContinuous-v0"])
 @pytest.mark.parametrize("projection", PROJECTIONS)
@@ -18,3 +18,10 @@ def test_trpl(env_id, projection):
 def test_trpl_params(env_id, projection, mean_bound, cov_bound):
     model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), projection_class=projection, projection_kwargs={'mean_bound': mean_bound, 'cov_bound': cov_bound}, verbose=1)
     model.learn(total_timesteps=100)
+
+@pytest.mark.parametrize("env_id", ["LunarLanderContinuous-v2"])
+@pytest.mark.parametrize("projection", PROJECTIONS)
+@pytest.mark.parametrize("par_strength", ['DIAG', 'FULL', 'CONT_DIAG', 'CONT_FULL'])
+def test_trpl_pca(env_id, projection, par_strength):
+    model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, use_pca=True, policy_kwargs=dict(net_arch=[16], dist_kwargs={'par_strength': par_strength, 'skip_conditioning': True}), projection_class=projection, verbose=1)
+    model.learn(total_timesteps=100)
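
For reference, a minimal standalone sketch of the call pattern the new PCA tests exercise, using only the arguments that appear in the diff (use_pca, dist_kwargs with par_strength and skip_conditioning); each value in the par_strength parametrize list ('DIAG', 'FULL', 'CONT_DIAG', 'CONT_FULL') is passed the same way.

    from metastable_baselines2 import PPO, TRPL

    # PPO with the PCA flag, one par_strength case from test_ppo_pca
    model = PPO(
        "MlpPolicy",
        "LunarLanderContinuous-v2",
        n_steps=128,
        seed=0,
        use_pca=True,
        policy_kwargs=dict(net_arch=[16], dist_kwargs={'par_strength': 'DIAG', 'skip_conditioning': True}),
        verbose=1,
    )
    model.learn(total_timesteps=100)

    # TRPL variant from test_trpl_pca; "Frobenius" is one of the two projections in PROJECTIONS
    model = TRPL(
        "MlpPolicy",
        "LunarLanderContinuous-v2",
        n_steps=128,
        seed=0,
        use_pca=True,
        policy_kwargs=dict(net_arch=[16], dist_kwargs={'par_strength': 'DIAG', 'skip_conditioning': True}),
        projection_class="Frobenius",
        verbose=1,
    )
    model.learn(total_timesteps=100)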