55 lines
1.4 KiB
Python
55 lines
1.4 KiB
Python
import pytest
|
|
import torch
|
|
from fancy_rl.ppo import PPO
|
|
from fancy_rl.policy import Policy
|
|
from fancy_rl.loggers import TerminalLogger
|
|
from fancy_rl.utils import make_env
|
|
|
|
@pytest.fixture
def policy():
    """Provide a fresh ``Policy`` for each test that requests it.

    The instance is created with ``input_dim=4``, ``output_dim=2`` and
    ``hidden_sizes=[64, 64]``.
    """
    hidden_sizes = [64, 64]
    network = Policy(input_dim=4, output_dim=2, hidden_sizes=hidden_sizes)
    return network
|
|
|
|
@pytest.fixture
def loggers():
    """Provide a one-element list holding a new ``TerminalLogger``."""
    terminal_logger = TerminalLogger()
    return [terminal_logger]
|
|
|
|
@pytest.fixture
def env_fn():
    """Provide whatever ``make_env`` returns for the ``"CartPole-v1"`` id."""
    env_id = "CartPole-v1"
    return make_env(env_id)
|
|
|
|
def test_ppo_train(policy, loggers, env_fn):
    """Smoke-test training: build a ``PPO`` from the fixtures and call ``train()``.

    There are no explicit assertions; the test passes when construction and
    ``train()`` complete without raising.
    """
    # Same keyword arguments, in the same order, as a direct PPO(...) call.
    ppo_kwargs = {
        "policy": policy,
        "env_fn": env_fn,
        "loggers": loggers,
        "learning_rate": 3e-4,
        "n_steps": 2048,
        "batch_size": 64,
        "n_epochs": 10,
        "gamma": 0.99,
        "gae_lambda": 0.95,
        "clip_range": 0.2,
        "total_timesteps": 10000,
        "eval_interval": 2048,
        "eval_deterministic": True,
        "eval_episodes": 5,
        "seed": 42,
    }
    agent = PPO(**ppo_kwargs)
    agent.train()
|
|
|
|
def test_ppo_evaluate(policy, loggers, env_fn):
    """Smoke-test evaluation: build a ``PPO`` and call ``evaluate(epoch=0)``.

    There are no explicit assertions; the test passes when construction and
    ``evaluate`` complete without raising.
    """
    # Same keyword arguments, in the same order, as a direct PPO(...) call.
    ppo_kwargs = {
        "policy": policy,
        "env_fn": env_fn,
        "loggers": loggers,
        "learning_rate": 3e-4,
        "n_steps": 2048,
        "batch_size": 64,
        "n_epochs": 10,
        "gamma": 0.99,
        "gae_lambda": 0.95,
        "clip_range": 0.2,
        "total_timesteps": 10000,
        "eval_interval": 2048,
        "eval_deterministic": True,
        "eval_episodes": 5,
        "seed": 42,
    }
    agent = PPO(**ppo_kwargs)
    agent.evaluate(epoch=0)
|