Fix trpl test using wrong hps
This commit is contained in:
parent
54bab221ef
commit
e6d78083aa
@ -15,21 +15,24 @@ def test_trpl_instantiation():
|
||||
@pytest.mark.parametrize("n_steps", [1024, 2048])
|
||||
@pytest.mark.parametrize("batch_size", [32, 64, 128])
|
||||
@pytest.mark.parametrize("gamma", [0.95, 0.99])
|
||||
@pytest.mark.parametrize("max_kl", [0.01, 0.05])
|
||||
def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, max_kl):
|
||||
@pytest.mark.parametrize("trust_region_bound_mean", [0.05, 0.1])
|
||||
@pytest.mark.parametrize("trust_region_bound_cov", [0.0005, 0.001])
|
||||
def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, trust_region_bound_mean, trust_region_bound_cov):
|
||||
trpl = TRPL(
|
||||
"CartPole-v1",
|
||||
learning_rate=learning_rate,
|
||||
n_steps=n_steps,
|
||||
batch_size=batch_size,
|
||||
gamma=gamma,
|
||||
max_kl=max_kl
|
||||
trust_region_bound_mean=trust_region_bound_mean,
|
||||
trust_region_bound_cov=trust_region_bound_cov
|
||||
)
|
||||
assert trpl.learning_rate == learning_rate
|
||||
assert trpl.n_steps == n_steps
|
||||
assert trpl.batch_size == batch_size
|
||||
assert trpl.gamma == gamma
|
||||
assert trpl.max_kl == max_kl
|
||||
assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean
|
||||
assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov
|
||||
|
||||
def test_trpl_predict(simple_env):
|
||||
trpl = TRPL("CartPole-v1")
|
||||
|
Loading…
Reference in New Issue
Block a user