From e6d78083aaf5abe53ec8cc907c6876bd4a50bc43 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 28 Aug 2024 11:58:58 +0200 Subject: [PATCH] Fix trpl test using wrong hps --- test/test_trpl.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/test_trpl.py b/test/test_trpl.py index 1ee3082..c7f640b 100644 --- a/test/test_trpl.py +++ b/test/test_trpl.py @@ -15,21 +15,24 @@ def test_trpl_instantiation(): @pytest.mark.parametrize("n_steps", [1024, 2048]) @pytest.mark.parametrize("batch_size", [32, 64, 128]) @pytest.mark.parametrize("gamma", [0.95, 0.99]) -@pytest.mark.parametrize("max_kl", [0.01, 0.05]) -def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, max_kl): +@pytest.mark.parametrize("trust_region_bound_mean", [0.05, 0.1]) +@pytest.mark.parametrize("trust_region_bound_cov", [0.0005, 0.001]) +def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, trust_region_bound_mean, trust_region_bound_cov): trpl = TRPL( "CartPole-v1", learning_rate=learning_rate, n_steps=n_steps, batch_size=batch_size, gamma=gamma, - max_kl=max_kl + trust_region_bound_mean=trust_region_bound_mean, + trust_region_bound_cov=trust_region_bound_cov ) assert trpl.learning_rate == learning_rate assert trpl.n_steps == n_steps assert trpl.batch_size == batch_size assert trpl.gamma == gamma - assert trpl.max_kl == max_kl + assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean + assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov def test_trpl_predict(simple_env): trpl = TRPL("CartPole-v1")