streamline tests, mitigate broken env binding

Dominik Moritz Roth 2025-01-22 13:46:16 +01:00
parent 04be117a95
commit 3816adef9a
2 changed files with 43 additions and 82 deletions

View File

@@ -3,9 +3,11 @@ import numpy as np
from fancy_rl import PPO
import gymnasium as gym
from torchrl.envs import GymEnv
import torch as th
from tensordict import TensorDict
def simple_env():
    return GymEnv('LunarLander-v2', continuous=True)
    return gym.make('LunarLander-v2')
def test_ppo_instantiation():
    ppo = PPO(simple_env)
@@ -15,69 +17,38 @@ def test_ppo_instantiation_from_str():
    ppo = PPO('CartPole-v1')
    assert isinstance(ppo, PPO)
def test_ppo_instantiation_from_make():
    ppo = PPO(gym.make('CartPole-v1'))
    assert isinstance(ppo, PPO)
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128])
@pytest.mark.parametrize("n_epochs", [5, 10])
@pytest.mark.parametrize("gamma", [0.95, 0.99])
@pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
    ppo = PPO(
        simple_env,
        learning_rate=learning_rate,
        n_steps=n_steps,
        batch_size=batch_size,
        n_epochs=n_epochs,
        gamma=gamma,
        clip_range=clip_range
    )
    assert ppo.learning_rate == learning_rate
    assert ppo.n_steps == n_steps
    assert ppo.batch_size == batch_size
    assert ppo.n_epochs == n_epochs
    assert ppo.gamma == gamma
    assert ppo.clip_range == clip_range
def test_ppo_predict():
    ppo = PPO(simple_env)
    env = ppo.make_env()
    obs, _ = env.reset()
    action, _ = ppo.predict(obs)
    assert isinstance(action, np.ndarray)
    assert action.shape == env.action_space.shape
    obs = env.reset()
    action = ppo.predict(obs)
    assert isinstance(action, TensorDict)
def test_ppo_learn():
    ppo = PPO(simple_env, n_steps=64, batch_size=32)
    env = ppo.make_env()
    obs = env.reset()
    for _ in range(64):
        action, _next_state = ppo.predict(obs)
        obs, reward, done, truncated, _ = env.step(action)
        if done or truncated:
            obs = env.reset()
    # Handle both single and composite action spaces
    if isinstance(env.action_space, list):
        expected_shape = (len(env.action_space),) + env.action_space[0].shape
    else:
        expected_shape = env.action_space.shape
    assert action["action"].shape == expected_shape
def test_ppo_training():
    ppo = PPO(simple_env, total_timesteps=10000)
    ppo = PPO(simple_env, total_timesteps=100)
    env = ppo.make_env()
    initial_performance = evaluate_policy(ppo, env)
    ppo.train()
    final_performance = evaluate_policy(ppo, env)
    assert final_performance > initial_performance, "PPO should improve performance after training"
def evaluate_policy(policy, env, n_eval_episodes=10):
def evaluate_policy(policy, env, n_eval_episodes=3):
    total_reward = 0
    for _ in range(n_eval_episodes):
        obs = env.reset()
        tensordict = env.reset()
        done = False
        while not done:
            action, _next_state = policy.predict(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            done = terminated or truncated
            action = policy.predict(tensordict)
            next_tensordict = env.step(action).get("next")
            total_reward += next_tensordict["reward"]
            done = next_tensordict["done"]
            tensordict = next_tensordict
    return total_reward / n_eval_episodes
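
For context, a minimal sketch of the TensorDict-style rollout that the rewritten evaluate_policy relies on: in torchrl, env.reset() returns a TensorDict, env.step() reads an "action" entry from that TensorDict, and the post-step observation, reward, and done flag are written into a nested "next" sub-TensorDict. The random action below is only a stand-in for fancy_rl's own predict(), whose exact return layout is assumed rather than shown in this diff.

    # Sketch only: torchrl GymEnv rollout with a random action in place of policy.predict.
    from torchrl.envs import GymEnv

    env = GymEnv("Pendulum-v1")
    td = env.reset()                            # TensorDict holding "observation", "done", ...
    total_reward, done = 0.0, False
    while not done:
        td["action"] = env.action_spec.rand()   # a trained agent would fill this entry instead
        td = env.step(td)                       # writes the "next" sub-TensorDict in place
        nxt = td.get("next")                    # "observation", "reward", "done" live here
        total_reward += nxt["reward"].item()
        done = bool(nxt["done"])
        td = nxt                                # same hand-over the tests use; torchrl's step_mdp() does this too
    print(total_reward)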

View File

@@ -2,9 +2,10 @@ import pytest
import numpy as np
from fancy_rl import TRPL
import gymnasium as gym
from tensordict import TensorDict
def simple_env():
    return gym.make('LunarLander-v2', continuous=True)
    return gym.make('Pendulum-v1')
def test_trpl_instantiation():
    trpl = TRPL(simple_env)
@@ -34,52 +35,41 @@ def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_si
    assert trpl.n_steps == n_steps
    assert trpl.batch_size == batch_size
    assert trpl.gamma == gamma
    assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean
    assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov
    assert trpl.projection.mean_bound == trust_region_bound_mean
    assert trpl.projection.cov_bound == trust_region_bound_cov
def test_trpl_predict():
    trpl = TRPL(simple_env)
    env = trpl.make_env()
    obs, _ = env.reset()
    action, _ = trpl.predict(obs)
    assert isinstance(action, np.ndarray)
    assert action.shape == env.action_space.shape
    obs = env.reset()
    action = trpl.predict(obs)
    assert isinstance(action, TensorDict)
def test_trpl_learn():
    trpl = TRPL(simple_env, n_steps=64, batch_size=32)
    env = trpl.make_env()
    obs, _ = env.reset()
    for _ in range(64):
        action, _ = trpl.predict(obs)
        next_obs, reward, done, truncated, _ = env.step(action)
        trpl.store_transition(obs, action, reward, done, next_obs)
        obs = next_obs
        if done or truncated:
            obs, _ = env.reset()
    # Handle both single and composite action spaces
    if isinstance(env.action_space, list):
        expected_shape = (len(env.action_space),) + env.action_space[0].shape
    else:
        expected_shape = env.action_space.shape
    loss = trpl.learn()
    assert isinstance(loss, dict)
    assert "policy_loss" in loss
    assert "value_loss" in loss
    assert action["action"].shape == expected_shape
def test_trpl_training():
    trpl = TRPL(simple_env, total_timesteps=10000)
    trpl = TRPL(simple_env, total_timesteps=100)
    env = trpl.make_env()
    initial_performance = evaluate_policy(trpl, env)
    trpl.train()
    final_performance = evaluate_policy(trpl, env)
    assert final_performance > initial_performance, "TRPL should improve performance after training"
def evaluate_policy(policy, env, n_eval_episodes=10):
def evaluate_policy(policy, env, n_eval_episodes=3):
    total_reward = 0
    for _ in range(n_eval_episodes):
        obs, _ = env.reset()
        tensordict = env.reset()
        done = False
        while not done:
            action, _ = policy.predict(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            done = terminated or truncated
            action = policy.predict(tensordict)
            next_tensordict = env.step(action).get("next")
            total_reward += next_tensordict["reward"]
            done = next_tensordict["done"]
            tensordict = next_tensordict
    return total_reward / n_eval_episodes
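
One detail worth noting about both evaluate_policy variants: next_tensordict["reward"] is a torch tensor (shape [1] for a single, unbatched torchrl env), so total_reward and the returned mean end up as 1-element tensors rather than Python floats. The final_performance > initial_performance asserts still hold because truth-testing a single-element tensor is allowed in PyTorch. A small illustration of that pattern, plain torch and independent of fancy_rl:

    import torch

    total_reward = 0
    total_reward += torch.tensor([1.5])        # same accumulation pattern as the loops above
    total_reward += torch.tensor([0.5])
    mean_reward = total_reward / 3             # tensor([0.6667]), not a float
    assert mean_reward > torch.tensor([0.1])   # 1-element comparison is truth-testable
    print(float(mean_reward))                  # convert explicitly if a plain float is needed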