Expand Tests
This commit is contained in:
parent
e927afcc30
commit
abc8dcbda1
@ -3,12 +3,15 @@ import numpy as np
|
|||||||
from fancy_rl import PPO
|
from fancy_rl import PPO
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def simple_env():
|
def simple_env():
|
||||||
return gym.make('CartPole-v1')
|
return gym.make('LunarLander-v2', continuous=True)
|
||||||
|
|
||||||
def test_ppo_instantiation():
|
def test_ppo_instantiation():
|
||||||
ppo = PPO("CartPole-v1")
|
ppo = PPO(simple_env)
|
||||||
|
assert isinstance(ppo, PPO)
|
||||||
|
|
||||||
|
def test_ppo_instantiation_from_str():
|
||||||
|
ppo = PPO('CartPole-v1')
|
||||||
assert isinstance(ppo, PPO)
|
assert isinstance(ppo, PPO)
|
||||||
|
|
||||||
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
|
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
|
||||||
@ -19,7 +22,7 @@ def test_ppo_instantiation():
|
|||||||
@pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
|
@pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
|
||||||
def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
|
def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
|
||||||
ppo = PPO(
|
ppo = PPO(
|
||||||
"CartPole-v1",
|
simple_env,
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
n_steps=n_steps,
|
n_steps=n_steps,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
@ -34,26 +37,42 @@ def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_siz
|
|||||||
assert ppo.gamma == gamma
|
assert ppo.gamma == gamma
|
||||||
assert ppo.clip_range == clip_range
|
assert ppo.clip_range == clip_range
|
||||||
|
|
||||||
def test_ppo_predict(simple_env):
|
def test_ppo_predict():
|
||||||
ppo = PPO("CartPole-v1")
|
ppo = PPO(simple_env)
|
||||||
obs, _ = simple_env.reset()
|
env = ppo.make_env()
|
||||||
|
obs, _ = env.reset()
|
||||||
action, _ = ppo.predict(obs)
|
action, _ = ppo.predict(obs)
|
||||||
assert isinstance(action, np.ndarray)
|
assert isinstance(action, np.ndarray)
|
||||||
assert action.shape == simple_env.action_space.shape
|
assert action.shape == env.action_space.shape
|
||||||
|
|
||||||
def test_ppo_learn():
|
def test_ppo_learn():
|
||||||
ppo = PPO("CartPole-v1", n_steps=64, batch_size=32)
|
ppo = PPO(simple_env, n_steps=64, batch_size=32)
|
||||||
env = gym.make("CartPole-v1")
|
env = ppo.make_env()
|
||||||
obs, _ = env.reset()
|
obs, _ = env.reset()
|
||||||
for _ in range(64):
|
for _ in range(64):
|
||||||
action, _ = ppo.predict(obs)
|
action, _ = ppo.predict(obs)
|
||||||
next_obs, reward, done, truncated, _ = env.step(action)
|
obs, reward, done, truncated, _ = env.step(action)
|
||||||
ppo.store_transition(obs, action, reward, done, next_obs)
|
|
||||||
obs = next_obs
|
|
||||||
if done or truncated:
|
if done or truncated:
|
||||||
obs, _ = env.reset()
|
obs, _ = env.reset()
|
||||||
|
|
||||||
loss = ppo.learn()
|
def test_ppo_training():
|
||||||
assert isinstance(loss, dict)
|
ppo = PPO(simple_env, total_timesteps=10000)
|
||||||
assert "policy_loss" in loss
|
env = ppo.make_env()
|
||||||
assert "value_loss" in loss
|
|
||||||
|
initial_performance = evaluate_policy(ppo, env)
|
||||||
|
ppo.train()
|
||||||
|
final_performance = evaluate_policy(ppo, env)
|
||||||
|
|
||||||
|
assert final_performance > initial_performance, "PPO should improve performance after training"
|
||||||
|
|
||||||
|
def evaluate_policy(policy, env, n_eval_episodes=10):
|
||||||
|
total_reward = 0
|
||||||
|
for _ in range(n_eval_episodes):
|
||||||
|
obs, _ = env.reset()
|
||||||
|
done = False
|
||||||
|
while not done:
|
||||||
|
action, _ = policy.predict(obs)
|
||||||
|
obs, reward, terminated, truncated, _ = env.step(action)
|
||||||
|
total_reward += reward
|
||||||
|
done = terminated or truncated
|
||||||
|
return total_reward / n_eval_episodes
|
@ -3,12 +3,15 @@ import numpy as np
|
|||||||
from fancy_rl import TRPL
|
from fancy_rl import TRPL
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def simple_env():
|
def simple_env():
|
||||||
return gym.make('CartPole-v1')
|
return gym.make('LunarLander-v2', continuous=True)
|
||||||
|
|
||||||
def test_trpl_instantiation():
|
def test_trpl_instantiation():
|
||||||
trpl = TRPL("CartPole-v1")
|
trpl = TRPL(simple_env)
|
||||||
|
assert isinstance(trpl, TRPL)
|
||||||
|
|
||||||
|
def test_trpl_instantiation_from_str():
|
||||||
|
trpl = TRPL('MountainCarContinuous-v0')
|
||||||
assert isinstance(trpl, TRPL)
|
assert isinstance(trpl, TRPL)
|
||||||
|
|
||||||
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
|
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
|
||||||
@ -19,7 +22,7 @@ def test_trpl_instantiation():
|
|||||||
@pytest.mark.parametrize("trust_region_bound_cov", [0.0005, 0.001])
|
@pytest.mark.parametrize("trust_region_bound_cov", [0.0005, 0.001])
|
||||||
def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, trust_region_bound_mean, trust_region_bound_cov):
|
def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, trust_region_bound_mean, trust_region_bound_cov):
|
||||||
trpl = TRPL(
|
trpl = TRPL(
|
||||||
"CartPole-v1",
|
simple_env,
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
n_steps=n_steps,
|
n_steps=n_steps,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
@ -34,16 +37,17 @@ def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_si
|
|||||||
assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean
|
assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean
|
||||||
assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov
|
assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov
|
||||||
|
|
||||||
def test_trpl_predict(simple_env):
|
def test_trpl_predict():
|
||||||
trpl = TRPL("CartPole-v1")
|
trpl = TRPL(simple_env)
|
||||||
obs, _ = simple_env.reset()
|
env = trpl.make_env()
|
||||||
|
obs, _ = env.reset()
|
||||||
action, _ = trpl.predict(obs)
|
action, _ = trpl.predict(obs)
|
||||||
assert isinstance(action, np.ndarray)
|
assert isinstance(action, np.ndarray)
|
||||||
assert action.shape == simple_env.action_space.shape
|
assert action.shape == env.action_space.shape
|
||||||
|
|
||||||
def test_trpl_learn():
|
def test_trpl_learn():
|
||||||
trpl = TRPL("CartPole-v1", n_steps=64, batch_size=32)
|
trpl = TRPL(simple_env, n_steps=64, batch_size=32)
|
||||||
env = gym.make("CartPole-v1")
|
env = trpl.make_env()
|
||||||
obs, _ = env.reset()
|
obs, _ = env.reset()
|
||||||
for _ in range(64):
|
for _ in range(64):
|
||||||
action, _ = trpl.predict(obs)
|
action, _ = trpl.predict(obs)
|
||||||
@ -58,12 +62,13 @@ def test_trpl_learn():
|
|||||||
assert "policy_loss" in loss
|
assert "policy_loss" in loss
|
||||||
assert "value_loss" in loss
|
assert "value_loss" in loss
|
||||||
|
|
||||||
def test_trpl_training(simple_env):
|
def test_trpl_training():
|
||||||
trpl = TRPL("CartPole-v1", total_timesteps=10000)
|
trpl = TRPL(simple_env, total_timesteps=10000)
|
||||||
|
env = trpl.make_env()
|
||||||
|
|
||||||
initial_performance = evaluate_policy(trpl, simple_env)
|
initial_performance = evaluate_policy(trpl, env)
|
||||||
trpl.train()
|
trpl.train()
|
||||||
final_performance = evaluate_policy(trpl, simple_env)
|
final_performance = evaluate_policy(trpl, env)
|
||||||
|
|
||||||
assert final_performance > initial_performance, "TRPL should improve performance after training"
|
assert final_performance > initial_performance, "TRPL should improve performance after training"
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user