Fix trpl test using wrong hps

This commit is contained in:
Dominik Moritz Roth 2024-08-28 11:58:58 +02:00
parent 54bab221ef
commit e6d78083aa

View File

@ -15,21 +15,24 @@ def test_trpl_instantiation():
@pytest.mark.parametrize("n_steps", [1024, 2048]) @pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128]) @pytest.mark.parametrize("batch_size", [32, 64, 128])
@pytest.mark.parametrize("gamma", [0.95, 0.99]) @pytest.mark.parametrize("gamma", [0.95, 0.99])
@pytest.mark.parametrize("max_kl", [0.01, 0.05]) @pytest.mark.parametrize("trust_region_bound_mean", [0.05, 0.1])
def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, max_kl): @pytest.mark.parametrize("trust_region_bound_cov", [0.0005, 0.001])
def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, trust_region_bound_mean, trust_region_bound_cov):
trpl = TRPL( trpl = TRPL(
"CartPole-v1", "CartPole-v1",
learning_rate=learning_rate, learning_rate=learning_rate,
n_steps=n_steps, n_steps=n_steps,
batch_size=batch_size, batch_size=batch_size,
gamma=gamma, gamma=gamma,
max_kl=max_kl trust_region_bound_mean=trust_region_bound_mean,
trust_region_bound_cov=trust_region_bound_cov
) )
assert trpl.learning_rate == learning_rate assert trpl.learning_rate == learning_rate
assert trpl.n_steps == n_steps assert trpl.n_steps == n_steps
assert trpl.batch_size == batch_size assert trpl.batch_size == batch_size
assert trpl.gamma == gamma assert trpl.gamma == gamma
assert trpl.max_kl == max_kl assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean
assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov
def test_trpl_predict(simple_env): def test_trpl_predict(simple_env):
trpl = TRPL("CartPole-v1") trpl = TRPL("CartPole-v1")