trl spec for .reset

Dominik Moritz Roth 2024-11-07 11:41:01 +01:00
parent 4f8fc500b7
commit e938018494


@@ -2,9 +2,10 @@ import pytest
 import numpy as np
 from fancy_rl import PPO
 import gymnasium as gym
+from torchrl.envs import GymEnv
 
 def simple_env():
-    return gym.make('LunarLander-v2', continuous=True)
+    return GymEnv('LunarLander-v2', continuous=True)
 
 def test_ppo_instantiation():
     ppo = PPO(simple_env)
@@ -14,6 +15,10 @@ def test_ppo_instantiation_from_str():
     ppo = PPO('CartPole-v1')
     assert isinstance(ppo, PPO)
 
+def test_ppo_instantiation_from_make():
+    ppo = PPO(gym.make('CartPole-v1'))
+    assert isinstance(ppo, PPO)
+
 @pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
 @pytest.mark.parametrize("n_steps", [1024, 2048])
 @pytest.mark.parametrize("batch_size", [32, 64, 128])
@@ -48,12 +53,12 @@ def test_ppo_predict():
 def test_ppo_learn():
     ppo = PPO(simple_env, n_steps=64, batch_size=32)
     env = ppo.make_env()
-    obs, _ = env.reset()
+    obs = env.reset()
     for _ in range(64):
-        action, _ = ppo.predict(obs)
+        action, _next_state = ppo.predict(obs)
         obs, reward, done, truncated, _ = env.step(action)
         if done or truncated:
-            obs, _ = env.reset()
+            obs = env.reset()
 
 def test_ppo_training():
     ppo = PPO(simple_env, total_timesteps=10000)
@@ -68,10 +73,10 @@ def test_ppo_training():
 def evaluate_policy(policy, env, n_eval_episodes=10):
     total_reward = 0
     for _ in range(n_eval_episodes):
-        obs, _ = env.reset()
+        obs = env.reset()
         done = False
         while not done:
-            action, _ = policy.predict(obs)
+            action, _next_state = policy.predict(obs)
             obs, reward, terminated, truncated, _ = env.step(action)
             total_reward += reward
             done = terminated or truncated
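
Context note (not part of this commit): Gymnasium's reset() returns an (observation, info) tuple, whereas TorchRL's GymEnv returns a single TensorDict from reset(), which is why the tuple unpacking is dropped in the hunks above. A minimal sketch of the two conventions, assuming gymnasium and torchrl are installed and using "observation" as the TensorDict key TorchRL typically assigns to a box observation space:

import gymnasium as gym
from torchrl.envs import GymEnv

# Gymnasium spec: reset() returns an (observation, info) tuple.
obs, info = gym.make('CartPole-v1').reset()

# TorchRL spec: reset() returns a single TensorDict; the observation
# is stored under a key instead of being returned directly.
td = GymEnv('CartPole-v1').reset()
obs = td["observation"]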