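"""Basic smoke tests for the PPO implementation in fancy_rl."""
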
import pytest
import numpy as np
from fancy_rl import PPO
import gymnasium as gym
from torchrl.envs import GymEnv
import torch as th
from tensordict import TensorDict


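# Environment factory handed to PPO as a callable; each call builds a fresh env.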
def simple_env():
    return gym.make('LunarLander-v2')


def test_ppo_instantiation():
    ppo = PPO(simple_env)
    assert isinstance(ppo, PPO)


def test_ppo_instantiation_from_str():
    ppo = PPO('CartPole-v1')
    assert isinstance(ppo, PPO)


def test_ppo_predict():
    ppo = PPO(simple_env)
    env = ppo.make_env()
    obs = env.reset()
    action = ppo.predict(obs)
    assert isinstance(action, TensorDict)

    # Handle both single and composite action spaces
    if isinstance(env.action_space, list):
        expected_shape = (len(env.action_space),) + env.action_space[0].shape
    else:
        expected_shape = env.action_space.shape

    assert action["action"].shape == expected_shape


def test_ppo_training():
    ppo = PPO(simple_env, total_timesteps=100)
    env = ppo.make_env()

    initial_performance = evaluate_policy(ppo, env)
    ppo.train()
    final_performance = evaluate_policy(ppo, env)

    # No assertion on reward improvement: 100 timesteps only exercises the
    # training loop, so the two evaluations are informational.


def evaluate_policy(policy, env, n_eval_episodes=3):
    """Roll the policy out for n_eval_episodes and return the mean total reward."""
    total_reward = 0
    for _ in range(n_eval_episodes):
        tensordict = env.reset()
        done = False
        while not done:
            action = policy.predict(tensordict)
            next_tensordict = env.step(action).get("next")
            total_reward += next_tensordict["reward"]
            done = next_tensordict["done"]
            tensordict = next_tensordict
    return total_reward / n_eval_episodes