From e938018494a1a5b512eb2a40ae9c4d5155d81299 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Thu, 7 Nov 2024 11:41:01 +0100 Subject: [PATCH] trl spec for .reset --- test/test_ppo.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/test/test_ppo.py b/test/test_ppo.py index 68bddb1..0d4b96a 100644 --- a/test/test_ppo.py +++ b/test/test_ppo.py @@ -2,9 +2,10 @@ import pytest import numpy as np from fancy_rl import PPO import gymnasium as gym +from torchrl.envs import GymEnv def simple_env(): - return gym.make('LunarLander-v2', continuous=True) + return GymEnv('LunarLander-v2', continuous=True) def test_ppo_instantiation(): ppo = PPO(simple_env) @@ -14,6 +15,10 @@ def test_ppo_instantiation_from_str(): ppo = PPO('CartPole-v1') assert isinstance(ppo, PPO) +def test_ppo_instantiation_from_make(): + ppo = PPO(gym.make('CartPole-v1')) + assert isinstance(ppo, PPO) + @pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3]) @pytest.mark.parametrize("n_steps", [1024, 2048]) @pytest.mark.parametrize("batch_size", [32, 64, 128]) @@ -48,12 +53,12 @@ def test_ppo_predict(): def test_ppo_learn(): ppo = PPO(simple_env, n_steps=64, batch_size=32) env = ppo.make_env() - obs, _ = env.reset() + obs = env.reset() for _ in range(64): - action, _ = ppo.predict(obs) + action, _next_state = ppo.predict(obs) obs, reward, done, truncated, _ = env.step(action) if done or truncated: - obs, _ = env.reset() + obs = env.reset() def test_ppo_training(): ppo = PPO(simple_env, total_timesteps=10000) @@ -68,10 +73,10 @@ def test_ppo_training(): def evaluate_policy(policy, env, n_eval_episodes=10): total_reward = 0 for _ in range(n_eval_episodes): - obs, _ = env.reset() + obs = env.reset() done = False while not done: - action, _ = policy.predict(obs) + action, _next_state = policy.predict(obs) obs, reward, terminated, truncated, _ = env.step(action) total_reward += reward done = terminated or truncated