torchrl spec for .reset
parent 4f8fc500b7
commit e938018494
@@ -2,9 +2,10 @@ import pytest
 import numpy as np
 from fancy_rl import PPO
 import gymnasium as gym
+from torchrl.envs import GymEnv
 
 def simple_env():
-    return gym.make('LunarLander-v2', continuous=True)
+    return GymEnv('LunarLander-v2', continuous=True)
 
 def test_ppo_instantiation():
     ppo = PPO(simple_env)
@@ -14,6 +15,10 @@ def test_ppo_instantiation_from_str():
     ppo = PPO('CartPole-v1')
     assert isinstance(ppo, PPO)
 
+def test_ppo_instantiation_from_make():
+    ppo = PPO(gym.make('CartPole-v1'))
+    assert isinstance(ppo, PPO)
+
 @pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
 @pytest.mark.parametrize("n_steps", [1024, 2048])
 @pytest.mark.parametrize("batch_size", [32, 64, 128])
@@ -48,12 +53,12 @@ def test_ppo_predict():
 def test_ppo_learn():
     ppo = PPO(simple_env, n_steps=64, batch_size=32)
     env = ppo.make_env()
-    obs, _ = env.reset()
+    obs = env.reset()
     for _ in range(64):
-        action, _ = ppo.predict(obs)
+        action, _next_state = ppo.predict(obs)
         obs, reward, done, truncated, _ = env.step(action)
         if done or truncated:
-            obs, _ = env.reset()
+            obs = env.reset()
 
 def test_ppo_training():
     ppo = PPO(simple_env, total_timesteps=10000)
@@ -68,10 +73,10 @@ def test_ppo_training():
 def evaluate_policy(policy, env, n_eval_episodes=10):
     total_reward = 0
     for _ in range(n_eval_episodes):
-        obs, _ = env.reset()
+        obs = env.reset()
         done = False
         while not done:
-            action, _ = policy.predict(obs)
+            action, _next_state = policy.predict(obs)
             obs, reward, terminated, truncated, _ = env.step(action)
             total_reward += reward
             done = terminated or truncated
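
For context on the .reset() changes above: gymnasium's Env.reset() returns an (observation, info) tuple, while torchrl's GymEnv.reset() returns a single TensorDict that carries the observation under a key. A minimal sketch of the two conventions, assuming only that gymnasium and torchrl are installed (none of this is fancy_rl API):

    import gymnasium as gym
    from torchrl.envs import GymEnv

    # gymnasium convention: reset() -> (observation, info) tuple
    gym_env = gym.make('LunarLander-v2', continuous=True)
    obs, info = gym_env.reset()

    # torchrl convention: reset() -> a single TensorDict, no tuple to unpack
    trl_env = GymEnv('LunarLander-v2', continuous=True)
    td = trl_env.reset()
    obs = td["observation"]                    # observation stored under a key

    # stepping is TensorDict-in / TensorDict-out as well
    td["action"] = trl_env.action_spec.rand()  # random valid action
    td = trl_env.step(td)
    next_obs = td["next", "observation"]       # step results live under "next"

This is the spec difference the commit addresses: once simple_env returns a GymEnv, reset() no longer yields a tuple, so the tests unpack a single value.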