Fix TRPL test using wrong hyperparameters
This commit is contained in:
parent
54bab221ef
commit
e6d78083aa
@ -15,21 +15,24 @@ def test_trpl_instantiation():
|
|||||||
@pytest.mark.parametrize("n_steps", [1024, 2048])
|
@pytest.mark.parametrize("n_steps", [1024, 2048])
|
||||||
@pytest.mark.parametrize("batch_size", [32, 64, 128])
|
@pytest.mark.parametrize("batch_size", [32, 64, 128])
|
||||||
@pytest.mark.parametrize("gamma", [0.95, 0.99])
|
@pytest.mark.parametrize("gamma", [0.95, 0.99])
|
||||||
@pytest.mark.parametrize("max_kl", [0.01, 0.05])
|
@pytest.mark.parametrize("trust_region_bound_mean", [0.05, 0.1])
|
||||||
def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, max_kl):
|
@pytest.mark.parametrize("trust_region_bound_cov", [0.0005, 0.001])
|
||||||
|
def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, trust_region_bound_mean, trust_region_bound_cov):
|
||||||
trpl = TRPL(
|
trpl = TRPL(
|
||||||
"CartPole-v1",
|
"CartPole-v1",
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
n_steps=n_steps,
|
n_steps=n_steps,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
gamma=gamma,
|
gamma=gamma,
|
||||||
max_kl=max_kl
|
trust_region_bound_mean=trust_region_bound_mean,
|
||||||
|
trust_region_bound_cov=trust_region_bound_cov
|
||||||
)
|
)
|
||||||
assert trpl.learning_rate == learning_rate
|
assert trpl.learning_rate == learning_rate
|
||||||
assert trpl.n_steps == n_steps
|
assert trpl.n_steps == n_steps
|
||||||
assert trpl.batch_size == batch_size
|
assert trpl.batch_size == batch_size
|
||||||
assert trpl.gamma == gamma
|
assert trpl.gamma == gamma
|
||||||
assert trpl.max_kl == max_kl
|
assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean
|
||||||
|
assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov
|
||||||
|
|
||||||
def test_trpl_predict(simple_env):
|
def test_trpl_predict(simple_env):
|
||||||
trpl = TRPL("CartPole-v1")
|
trpl = TRPL("CartPole-v1")
|
||||||
|
Loading…
Reference in New Issue
Block a user