diff --git a/tests/test_ppo.py b/tests/test_ppo.py new file mode 100644 index 0000000..17b4b0d --- /dev/null +++ b/tests/test_ppo.py @@ -0,0 +1,9 @@ +import gymnasium as gym +import pytest + +from metastable_baselines2 import PPO + +@pytest.mark.parametrize("env_id", ["LunarLanderContinuous-v2", "MountainCarContinuous-v0"]) +def test_trpl(env_id): + model = PPO("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1) + model.learn(total_timesteps=500) \ No newline at end of file diff --git a/tests/test_trpl.py b/tests/test_trpl.py new file mode 100644 index 0000000..07e715d --- /dev/null +++ b/tests/test_trpl.py @@ -0,0 +1,20 @@ +import gymnasium as gym +import pytest + +from metastable_baselines2 import TRPL + +PROJECTIONS = ["Frobenius", "Wasserstein"] #, "KL"] + +@pytest.mark.parametrize("env_id", ["LunarLanderContinuous-v2", "MountainCarContinuous-v0"]) +@pytest.mark.parametrize("projection", PROJECTIONS) +def test_trpl(env_id, projection): + model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), projection_class=projection, verbose=1) + model.learn(total_timesteps=500) + +@pytest.mark.parametrize("env_id", ["LunarLanderContinuous-v2"]) +@pytest.mark.parametrize("projection", PROJECTIONS) +@pytest.mark.parametrize("mean_bound", [0.03, 0.06]) +@pytest.mark.parametrize("cov_bound", [1.0e-3, 2.0e-3]) +def test_trpl_params(env_id, projection, mean_bound, cov_bound): + model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), projection_class=projection, projection_kwargs={'mean_bound': mean_bound, 'cov_bound': cov_bound}, verbose=1) + model.learn(total_timesteps=100)