From abc8dcbda1205591378938c660b307582633579b Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Mon, 21 Oct 2024 15:24:36 +0200
Subject: [PATCH] Expand Tests

---
 test/test_ppo.py  | 53 ++++++++++++++++++++++++++++++++---------------
 test/test_trpl.py | 33 ++++++++++++++++-------------
 2 files changed, 55 insertions(+), 31 deletions(-)

diff --git a/test/test_ppo.py b/test/test_ppo.py
index e71d58a..68bddb1 100644
--- a/test/test_ppo.py
+++ b/test/test_ppo.py
@@ -3,12 +3,15 @@ import numpy as np
 from fancy_rl import PPO
 import gymnasium as gym
 
-@pytest.fixture
 def simple_env():
-    return gym.make('CartPole-v1')
+    return gym.make('LunarLander-v2', continuous=True)
 
 def test_ppo_instantiation():
-    ppo = PPO("CartPole-v1")
+    ppo = PPO(simple_env)
+    assert isinstance(ppo, PPO)
+
+def test_ppo_instantiation_from_str():
+    ppo = PPO('CartPole-v1')
     assert isinstance(ppo, PPO)
 
 @pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@@ -19,7 +22,7 @@ def test_ppo_instantiation():
 @pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
 def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
     ppo = PPO(
-        "CartPole-v1",
+        simple_env,
         learning_rate=learning_rate,
         n_steps=n_steps,
         batch_size=batch_size,
@@ -34,26 +37,42 @@ def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_siz
     assert ppo.gamma == gamma
     assert ppo.clip_range == clip_range
 
-def test_ppo_predict(simple_env):
-    ppo = PPO("CartPole-v1")
-    obs, _ = simple_env.reset()
+def test_ppo_predict():
+    ppo = PPO(simple_env)
+    env = ppo.make_env()
+    obs, _ = env.reset()
     action, _ = ppo.predict(obs)
     assert isinstance(action, np.ndarray)
-    assert action.shape == simple_env.action_space.shape
+    assert action.shape == env.action_space.shape
 
 def test_ppo_learn():
-    ppo = PPO("CartPole-v1", n_steps=64, batch_size=32)
-    env = gym.make("CartPole-v1")
+    ppo = PPO(simple_env, n_steps=64, batch_size=32)
+    env = ppo.make_env()
     obs, _ = env.reset()
     for _ in range(64):
         action, _ = ppo.predict(obs)
-        next_obs, reward, done, truncated, _ = env.step(action)
-        ppo.store_transition(obs, action, reward, done, next_obs)
-        obs = next_obs
+        obs, reward, done, truncated, _ = env.step(action)
         if done or truncated:
             obs, _ = env.reset()
+
+def test_ppo_training():
+    ppo = PPO(simple_env, total_timesteps=10000)
+    env = ppo.make_env()
 
-    loss = ppo.learn()
-    assert isinstance(loss, dict)
-    assert "policy_loss" in loss
-    assert "value_loss" in loss
\ No newline at end of file
+    initial_performance = evaluate_policy(ppo, env)
+    ppo.train()
+    final_performance = evaluate_policy(ppo, env)
+
+    assert final_performance > initial_performance, "PPO should improve performance after training"
+
+def evaluate_policy(policy, env, n_eval_episodes=10):
+    total_reward = 0
+    for _ in range(n_eval_episodes):
+        obs, _ = env.reset()
+        done = False
+        while not done:
+            action, _ = policy.predict(obs)
+            obs, reward, terminated, truncated, _ = env.step(action)
+            total_reward += reward
+            done = terminated or truncated
+    return total_reward / n_eval_episodes
\ No newline at end of file
diff --git a/test/test_trpl.py b/test/test_trpl.py
index c7f640b..d302c23 100644
--- a/test/test_trpl.py
+++ b/test/test_trpl.py
@@ -3,12 +3,15 @@ import numpy as np
 from fancy_rl import TRPL
 import gymnasium as gym
 
-@pytest.fixture
 def simple_env():
-    return gym.make('CartPole-v1')
+    return gym.make('LunarLander-v2', continuous=True)
 
 def test_trpl_instantiation():
-    trpl = TRPL("CartPole-v1")
+    trpl = TRPL(simple_env)
+    assert isinstance(trpl, TRPL)
+
+def test_trpl_instantiation_from_str():
+    trpl = TRPL('MountainCarContinuous-v0')
     assert isinstance(trpl, TRPL)
 
 @pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@@ -19,7 +22,7 @@ def test_trpl_instantiation():
 @pytest.mark.parametrize("trust_region_bound_cov", [0.0005, 0.001])
 def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, trust_region_bound_mean, trust_region_bound_cov):
     trpl = TRPL(
-        "CartPole-v1",
+        simple_env,
         learning_rate=learning_rate,
         n_steps=n_steps,
         batch_size=batch_size,
@@ -34,16 +37,17 @@ def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_si
     assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean
     assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov
 
-def test_trpl_predict(simple_env):
-    trpl = TRPL("CartPole-v1")
-    obs, _ = simple_env.reset()
+def test_trpl_predict():
+    trpl = TRPL(simple_env)
+    env = trpl.make_env()
+    obs, _ = env.reset()
     action, _ = trpl.predict(obs)
     assert isinstance(action, np.ndarray)
-    assert action.shape == simple_env.action_space.shape
+    assert action.shape == env.action_space.shape
 
 def test_trpl_learn():
-    trpl = TRPL("CartPole-v1", n_steps=64, batch_size=32)
-    env = gym.make("CartPole-v1")
+    trpl = TRPL(simple_env, n_steps=64, batch_size=32)
+    env = trpl.make_env()
     obs, _ = env.reset()
     for _ in range(64):
         action, _ = trpl.predict(obs)
@@ -58,12 +62,13 @@ def test_trpl_learn():
     assert "policy_loss" in loss
     assert "value_loss" in loss
 
-def test_trpl_training(simple_env):
-    trpl = TRPL("CartPole-v1", total_timesteps=10000)
+def test_trpl_training():
+    trpl = TRPL(simple_env, total_timesteps=10000)
+    env = trpl.make_env()
 
-    initial_performance = evaluate_policy(trpl, simple_env)
+    initial_performance = evaluate_policy(trpl, env)
     trpl.train()
-    final_performance = evaluate_policy(trpl, simple_env)
+    final_performance = evaluate_policy(trpl, env)
 
     assert final_performance > initial_performance, "TRPL should improve performance after training"