fancy_rl/test/test_ppo.py

59 lines
1.8 KiB
Python
Raw Normal View History

2024-08-28 11:30:37 +02:00
import pytest
import numpy as np
from fancy_rl import PPO
import gymnasium as gym
@pytest.fixture
def simple_env():
    """Provide a fresh CartPole-v1 environment and close it after the test.

    The fixture yields instead of returning, so pytest runs the code after
    ``yield`` as teardown once the test finishes (pass or fail). That
    releases the environment's resources instead of leaking them.
    """
    env = gym.make('CartPole-v1')
    yield env
    env.close()
def test_ppo_instantiation():
    """Check that PPO can be constructed from the "CartPole-v1" id string."""
    agent = PPO("CartPole-v1")
    assert isinstance(agent, PPO)
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128])
@pytest.mark.parametrize("n_epochs", [5, 10])
@pytest.mark.parametrize("gamma", [0.95, 0.99])
@pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
    """Check that each hyperparameter passed to PPO is exposed unchanged.

    The stacked ``parametrize`` decorators run this test over the full
    Cartesian product of the listed values (3*2*3*2*2*3 = 216 cases). For
    every keyword given to the constructor, the agent must carry an
    attribute of the same name holding an equal value.
    """
    # Insertion order fixes the order in which attributes are checked.
    hyperparams = {
        "learning_rate": learning_rate,
        "n_steps": n_steps,
        "batch_size": batch_size,
        "n_epochs": n_epochs,
        "gamma": gamma,
        "clip_range": clip_range,
    }
    agent = PPO("CartPole-v1", **hyperparams)
    for attr_name, expected_value in hyperparams.items():
        assert getattr(agent, attr_name) == expected_value
def test_ppo_predict(simple_env):
    """Check predict() on a freshly reset observation.

    The returned action must be a NumPy array whose shape equals the
    environment's ``action_space.shape``. The second element of predict()'s
    return value is ignored.
    """
    agent = PPO("CartPole-v1")
    observation, _info = simple_env.reset()
    predicted_action, _extra = agent.predict(observation)
    assert isinstance(predicted_action, np.ndarray)
    expected_shape = simple_env.action_space.shape
    assert predicted_action.shape == expected_shape
def test_ppo_learn():
    """Collect one rollout of ``n_steps`` transitions, then check learn().

    A separate CartPole-v1 environment is stepped with the agent's own
    predicted actions. Each transition is handed to ``store_transition``, and
    the environment is reset whenever an episode terminates or is truncated.
    Afterwards, ``learn()`` must return a dict containing ``"policy_loss"``
    and ``"value_loss"``.
    """
    # Single source of truth for the rollout length: the agent's buffer size
    # and the number of collected transitions must agree.
    n_steps = 64
    ppo = PPO("CartPole-v1", n_steps=n_steps, batch_size=32)
    env = gym.make("CartPole-v1")
    try:
        obs, _ = env.reset()
        for _ in range(n_steps):
            action, _ = ppo.predict(obs)
            next_obs, reward, terminated, truncated, _ = env.step(action)
            # NOTE(review): only `terminated` is stored as the done flag, so
            # time-limit truncations are recorded as non-terminal — confirm
            # this matches what PPO.store_transition expects.
            ppo.store_transition(obs, action, reward, terminated, next_obs)
            obs = next_obs
            if terminated or truncated:
                obs, _ = env.reset()
    finally:
        # Release the environment even if predict/step/store_transition raises.
        env.close()
    losses = ppo.learn()
    assert isinstance(losses, dict)
    assert "policy_loss" in losses
    assert "value_loss" in losses