Expand Tests

This commit is contained in:
Dominik Moritz Roth 2024-10-21 15:24:36 +02:00
parent e927afcc30
commit abc8dcbda1
2 changed files with 55 additions and 31 deletions

View File

@ -3,12 +3,15 @@ import numpy as np
from fancy_rl import PPO from fancy_rl import PPO
import gymnasium as gym import gymnasium as gym
@pytest.fixture
def simple_env(): def simple_env():
return gym.make('CartPole-v1') return gym.make('LunarLander-v2', continuous=True)
def test_ppo_instantiation(): def test_ppo_instantiation():
ppo = PPO("CartPole-v1") ppo = PPO(simple_env)
assert isinstance(ppo, PPO)
def test_ppo_instantiation_from_str():
ppo = PPO('CartPole-v1')
assert isinstance(ppo, PPO) assert isinstance(ppo, PPO)
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3]) @pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@ -19,7 +22,7 @@ def test_ppo_instantiation():
@pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3]) @pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range): def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
ppo = PPO( ppo = PPO(
"CartPole-v1", simple_env,
learning_rate=learning_rate, learning_rate=learning_rate,
n_steps=n_steps, n_steps=n_steps,
batch_size=batch_size, batch_size=batch_size,
@ -34,26 +37,42 @@ def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_siz
assert ppo.gamma == gamma assert ppo.gamma == gamma
assert ppo.clip_range == clip_range assert ppo.clip_range == clip_range
def test_ppo_predict(simple_env): def test_ppo_predict():
ppo = PPO("CartPole-v1") ppo = PPO(simple_env)
obs, _ = simple_env.reset() env = ppo.make_env()
obs, _ = env.reset()
action, _ = ppo.predict(obs) action, _ = ppo.predict(obs)
assert isinstance(action, np.ndarray) assert isinstance(action, np.ndarray)
assert action.shape == simple_env.action_space.shape assert action.shape == env.action_space.shape
def test_ppo_learn(): def test_ppo_learn():
ppo = PPO("CartPole-v1", n_steps=64, batch_size=32) ppo = PPO(simple_env, n_steps=64, batch_size=32)
env = gym.make("CartPole-v1") env = ppo.make_env()
obs, _ = env.reset() obs, _ = env.reset()
for _ in range(64): for _ in range(64):
action, _ = ppo.predict(obs) action, _ = ppo.predict(obs)
next_obs, reward, done, truncated, _ = env.step(action) obs, reward, done, truncated, _ = env.step(action)
ppo.store_transition(obs, action, reward, done, next_obs)
obs = next_obs
if done or truncated: if done or truncated:
obs, _ = env.reset() obs, _ = env.reset()
loss = ppo.learn() def test_ppo_training():
assert isinstance(loss, dict) ppo = PPO(simple_env, total_timesteps=10000)
assert "policy_loss" in loss env = ppo.make_env()
assert "value_loss" in loss
initial_performance = evaluate_policy(ppo, env)
ppo.train()
final_performance = evaluate_policy(ppo, env)
assert final_performance > initial_performance, "PPO should improve performance after training"
def evaluate_policy(policy, env, n_eval_episodes=10):
    """Return the mean undiscounted episode return of *policy* on *env*.

    Runs ``n_eval_episodes`` complete episodes. Each episode starts with
    ``env.reset()`` and ends as soon as ``env.step`` reports either
    ``terminated`` or ``truncated``.

    Args:
        policy: Object with ``predict(obs)`` returning a 2-tuple whose first
            element is the action to take.
        env: Object with ``reset()`` returning ``(obs, info)`` and
            ``step(action)`` returning
            ``(obs, reward, terminated, truncated, info)``.
        n_eval_episodes: Number of episodes to average over; must be >= 1.

    Returns:
        Sum of all rewards collected, divided by ``n_eval_episodes``.

    Raises:
        ValueError: If ``n_eval_episodes`` is smaller than 1.
    """
    # Fail with a clear message instead of a ZeroDivisionError (or a
    # meaningless -0.0 for negative counts) at the final division.
    if n_eval_episodes < 1:
        raise ValueError(
            f"n_eval_episodes must be at least 1, got {n_eval_episodes}"
        )

    total_reward = 0.0
    for _ in range(n_eval_episodes):
        obs, _ = env.reset()
        done = False
        while not done:
            action, _ = policy.predict(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            # A time-limit truncation ends the episode just like termination.
            done = terminated or truncated
    return total_reward / n_eval_episodes

View File

@ -3,12 +3,15 @@ import numpy as np
from fancy_rl import TRPL from fancy_rl import TRPL
import gymnasium as gym import gymnasium as gym
@pytest.fixture
def simple_env(): def simple_env():
return gym.make('CartPole-v1') return gym.make('LunarLander-v2', continuous=True)
def test_trpl_instantiation(): def test_trpl_instantiation():
trpl = TRPL("CartPole-v1") trpl = TRPL(simple_env)
assert isinstance(trpl, TRPL)
def test_trpl_instantiation_from_str():
trpl = TRPL('MountainCarContinuous-v0')
assert isinstance(trpl, TRPL) assert isinstance(trpl, TRPL)
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3]) @pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@ -19,7 +22,7 @@ def test_trpl_instantiation():
@pytest.mark.parametrize("trust_region_bound_cov", [0.0005, 0.001]) @pytest.mark.parametrize("trust_region_bound_cov", [0.0005, 0.001])
def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, trust_region_bound_mean, trust_region_bound_cov): def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, trust_region_bound_mean, trust_region_bound_cov):
trpl = TRPL( trpl = TRPL(
"CartPole-v1", simple_env,
learning_rate=learning_rate, learning_rate=learning_rate,
n_steps=n_steps, n_steps=n_steps,
batch_size=batch_size, batch_size=batch_size,
@ -34,16 +37,17 @@ def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_si
assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean
assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov
def test_trpl_predict(simple_env): def test_trpl_predict():
trpl = TRPL("CartPole-v1") trpl = TRPL(simple_env)
obs, _ = simple_env.reset() env = trpl.make_env()
obs, _ = env.reset()
action, _ = trpl.predict(obs) action, _ = trpl.predict(obs)
assert isinstance(action, np.ndarray) assert isinstance(action, np.ndarray)
assert action.shape == simple_env.action_space.shape assert action.shape == env.action_space.shape
def test_trpl_learn(): def test_trpl_learn():
trpl = TRPL("CartPole-v1", n_steps=64, batch_size=32) trpl = TRPL(simple_env, n_steps=64, batch_size=32)
env = gym.make("CartPole-v1") env = trpl.make_env()
obs, _ = env.reset() obs, _ = env.reset()
for _ in range(64): for _ in range(64):
action, _ = trpl.predict(obs) action, _ = trpl.predict(obs)
@ -58,12 +62,13 @@ def test_trpl_learn():
assert "policy_loss" in loss assert "policy_loss" in loss
assert "value_loss" in loss assert "value_loss" in loss
def test_trpl_training(simple_env): def test_trpl_training():
trpl = TRPL("CartPole-v1", total_timesteps=10000) trpl = TRPL(simple_env, total_timesteps=10000)
env = trpl.make_env()
initial_performance = evaluate_policy(trpl, simple_env) initial_performance = evaluate_policy(trpl, env)
trpl.train() trpl.train()
final_performance = evaluate_policy(trpl, simple_env) final_performance = evaluate_policy(trpl, env)
assert final_performance > initial_performance, "TRPL should improve performance after training" assert final_performance > initial_performance, "TRPL should improve performance after training"