fancy_rl/test/test_trpl.py

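"""Tests for the TRPL algorithm in fancy_rl.

Covers instantiation from an environment factory and from a Gymnasium id string,
hyperparameter wiring (including the trust-region mean and covariance bounds
exposed on the projection), the shape of predicted actions, and a short
end-to-end training smoke test.
"""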

import pytest
import numpy as np
from fancy_rl import TRPL
import gymnasium as gym
from tensordict import TensorDict


def simple_env():
    # Small continuous-control environment factory used throughout these tests.
    return gym.make('Pendulum-v1')


def test_trpl_instantiation():
    trpl = TRPL(simple_env)
    assert isinstance(trpl, TRPL)


def test_trpl_instantiation_from_str():
    trpl = TRPL('MountainCarContinuous-v0')
    assert isinstance(trpl, TRPL)
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128])
@pytest.mark.parametrize("gamma", [0.95, 0.99])
@pytest.mark.parametrize("trust_region_bound_mean", [0.05, 0.1])
@pytest.mark.parametrize("trust_region_bound_cov", [0.0005, 0.001])
def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, trust_region_bound_mean, trust_region_bound_cov):
trpl = TRPL(
simple_env,
learning_rate=learning_rate,
n_steps=n_steps,
batch_size=batch_size,
gamma=gamma,
trust_region_bound_mean=trust_region_bound_mean,
trust_region_bound_cov=trust_region_bound_cov
)
assert trpl.learning_rate == learning_rate
assert trpl.n_steps == n_steps
assert trpl.batch_size == batch_size
assert trpl.gamma == gamma
assert trpl.projection.mean_bound == trust_region_bound_mean
assert trpl.projection.cov_bound == trust_region_bound_cov


def test_trpl_predict():
    trpl = TRPL(simple_env)
    env = trpl.make_env()
    obs = env.reset()
    action = trpl.predict(obs)
    assert isinstance(action, TensorDict)
    # Handle both single and composite action spaces
    if isinstance(env.action_space, list):
        expected_shape = (len(env.action_space),) + env.action_space[0].shape
    else:
        expected_shape = env.action_space.shape
    assert action["action"].shape == expected_shape


def test_trpl_training():
    trpl = TRPL(simple_env, total_timesteps=100)
    env = trpl.make_env()
    initial_performance = evaluate_policy(trpl, env)
    trpl.train()
    final_performance = evaluate_policy(trpl, env)
    # 100 timesteps is far too little training to expect improvement, so this is
    # only a smoke test: training runs end-to-end and evaluation stays well-defined.
    assert np.isfinite(float(initial_performance))
    assert np.isfinite(float(final_performance))


def evaluate_policy(policy, env, n_eval_episodes=3):
    """Roll out the policy for a few episodes and return the mean episodic reward."""
    total_reward = 0
    for _ in range(n_eval_episodes):
        tensordict = env.reset()
        done = False
        while not done:
            action = policy.predict(tensordict)
            next_tensordict = env.step(action).get("next")
            total_reward += next_tensordict["reward"]
            done = next_tensordict["done"]
            tensordict = next_tensordict
    return total_reward / n_eval_episodes
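

# Optional convenience so the file can also be executed directly; pytest's
# normal test collection does not need this, so treat it as an extra rather
# than part of the suite.
if __name__ == "__main__":
    pytest.main([__file__])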