trl spec for .reset
parent 4f8fc500b7
commit e938018494
@@ -2,9 +2,10 @@ import pytest
 import numpy as np
 from fancy_rl import PPO
 import gymnasium as gym
+from torchrl.envs import GymEnv

 def simple_env():
-    return gym.make('LunarLander-v2', continuous=True)
+    return GymEnv('LunarLander-v2', continuous=True)

 def test_ppo_instantiation():
     ppo = PPO(simple_env)
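
Note: a minimal sketch of the factory swap above, assuming only that TorchRL's GymEnv forwards extra keyword arguments (such as continuous=True) to the underlying gymnasium constructor; 'Pendulum-v1' is used here purely to avoid LunarLander's Box2D dependency.

import gymnasium as gym
from torchrl.envs import GymEnv

# old-style factory: returns a plain gymnasium.Env
gym_style = gym.make('Pendulum-v1')

# new-style factory: wraps the same environment in a TorchRL EnvBase;
# extra keyword arguments are forwarded to the gym constructor
trl_style = GymEnv('Pendulum-v1')
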
@@ -14,6 +15,10 @@ def test_ppo_instantiation_from_str():
     ppo = PPO('CartPole-v1')
     assert isinstance(ppo, PPO)

+def test_ppo_instantiation_from_make():
+    ppo = PPO(gym.make('CartPole-v1'))
+    assert isinstance(ppo, PPO)
+
 @pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
 @pytest.mark.parametrize("n_steps", [1024, 2048])
 @pytest.mark.parametrize("batch_size", [32, 64, 128])
@@ -48,12 +53,12 @@ def test_ppo_predict():
 def test_ppo_learn():
     ppo = PPO(simple_env, n_steps=64, batch_size=32)
     env = ppo.make_env()
-    obs, _ = env.reset()
+    obs = env.reset()
     for _ in range(64):
-        action, _ = ppo.predict(obs)
+        action, _next_state = ppo.predict(obs)
         obs, reward, done, truncated, _ = env.step(action)
         if done or truncated:
-            obs, _ = env.reset()
+            obs = env.reset()

 def test_ppo_training():
     ppo = PPO(simple_env, total_timesteps=10000)
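
Note: the obs = env.reset() change matches the differing reset contracts of the two APIs: gymnasium's reset() returns an (observation, info) tuple, while TorchRL's reset() returns a single TensorDict. A hedged sketch of that difference follows; whether ppo.make_env() actually returns the TorchRL wrapper is an assumption, not shown in this diff.

import gymnasium as gym
from torchrl.envs import GymEnv

# gymnasium: reset() -> (observation, info) tuple
obs, info = gym.make('CartPole-v1').reset()

# torchrl: reset() -> a single TensorDict; the observation sits under a key
td = GymEnv('CartPole-v1').reset()
obs = td['observation']
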
@@ -68,10 +73,10 @@ def test_ppo_training():
 def evaluate_policy(policy, env, n_eval_episodes=10):
     total_reward = 0
     for _ in range(n_eval_episodes):
-        obs, _ = env.reset()
+        obs = env.reset()
         done = False
         while not done:
-            action, _ = policy.predict(obs)
+            action, _next_state = policy.predict(obs)
             obs, reward, terminated, truncated, _ = env.step(action)
             total_reward += reward
             done = terminated or truncated