# fancy_rl/test/test_ppo.py

import pytest
import numpy as np
from fancy_rl import PPO
import gymnasium as gym
from torchrl.envs import GymEnv
import torch as th
from tensordict import TensorDict


def simple_env():
    # Environment factory used by most tests below; PPO receives the callable.
    # Note: LunarLander-v2 needs the Box2D extra (pip install "gymnasium[box2d]").
    return gym.make('LunarLander-v2')


def test_ppo_instantiation():
    ppo = PPO(simple_env)
    assert isinstance(ppo, PPO)


def test_ppo_instantiation_from_str():
    ppo = PPO('CartPole-v1')
    assert isinstance(ppo, PPO)


def test_ppo_predict():
    ppo = PPO(simple_env)
    env = ppo.make_env()
    obs = env.reset()
    action = ppo.predict(obs)
    assert isinstance(action, TensorDict)
    # Handle both single and composite action spaces
    if isinstance(env.action_space, list):
        expected_shape = (len(env.action_space),) + env.action_space[0].shape
    else:
        expected_shape = env.action_space.shape
    assert action["action"].shape == expected_shape
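

def test_ppo_predict_step():
    # Illustrative sketch, not part of the original suite: feed the predicted
    # action TensorDict back into the environment, mirroring the
    # env.step(...).get("next") pattern used in evaluate_policy below, and
    # check that reward/done entries come back.
    ppo = PPO(simple_env)
    env = ppo.make_env()
    obs = env.reset()
    action = ppo.predict(obs)
    next_td = env.step(action).get("next")
    assert "reward" in next_td.keys()
    assert "done" in next_td.keys()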


def test_ppo_training():
    ppo = PPO(simple_env, total_timesteps=100)
    env = ppo.make_env()
    initial_performance = evaluate_policy(ppo, env)
    ppo.train()
    final_performance = evaluate_policy(ppo, env)
    # No performance assertion: with only 100 training timesteps an improvement
    # over initial_performance is not guaranteed, so this test only checks that
    # training and evaluation run end to end.


def evaluate_policy(policy, env, n_eval_episodes=3):
    # Rolls out the policy for a few episodes and returns the mean episodic return.
    total_reward = 0
    for _ in range(n_eval_episodes):
        tensordict = env.reset()
        done = False
        while not done:
            action = policy.predict(tensordict)
            next_tensordict = env.step(action).get("next")
            total_reward += next_tensordict["reward"]
            done = next_tensordict["done"]
            tensordict = next_tensordict
    return total_reward / n_eval_episodes
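

if __name__ == "__main__":
    # Convenience entry point (an addition, not in the original file): lets the
    # module be run directly instead of via the pytest CLI; it relies only on
    # the pytest import already at the top of this file.
    pytest.main([__file__])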