Implemented new test suite
This commit is contained in: parent 416c2036a5, commit d29417187f
@@ -1 +1,59 @@
# TODO
import pytest
import numpy as np
from fancy_rl import PPO
import gymnasium as gym

@pytest.fixture
def simple_env():
    return gym.make('CartPole-v1')

def test_ppo_instantiation():
    ppo = PPO("CartPole-v1")
    assert isinstance(ppo, PPO)

@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128])
@pytest.mark.parametrize("n_epochs", [5, 10])
@pytest.mark.parametrize("gamma", [0.95, 0.99])
@pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
    ppo = PPO(
        "CartPole-v1",
        learning_rate=learning_rate,
        n_steps=n_steps,
        batch_size=batch_size,
        n_epochs=n_epochs,
        gamma=gamma,
        clip_range=clip_range
    )
    assert ppo.learning_rate == learning_rate
    assert ppo.n_steps == n_steps
    assert ppo.batch_size == batch_size
    assert ppo.n_epochs == n_epochs
    assert ppo.gamma == gamma
    assert ppo.clip_range == clip_range

def test_ppo_predict(simple_env):
    ppo = PPO("CartPole-v1")
    obs, _ = simple_env.reset()
    action, _ = ppo.predict(obs)
    assert isinstance(action, np.ndarray)
    assert action.shape == simple_env.action_space.shape

def test_ppo_learn():
    ppo = PPO("CartPole-v1", n_steps=64, batch_size=32)
    env = gym.make("CartPole-v1")
    obs, _ = env.reset()
    for _ in range(64):
        action, _ = ppo.predict(obs)
        next_obs, reward, done, truncated, _ = env.step(action)
        ppo.store_transition(obs, action, reward, done, next_obs)
        obs = next_obs
        if done or truncated:
            obs, _ = env.reset()

    loss = ppo.learn()
    assert isinstance(loss, dict)
    assert "policy_loss" in loss
    assert "value_loss" in loss
77  test/test_trpl.py  Normal file
@@ -0,0 +1,77 @@
import pytest
import numpy as np
from fancy_rl import TRPL
import gymnasium as gym

@pytest.fixture
def simple_env():
    return gym.make('CartPole-v1')

def test_trpl_instantiation():
    trpl = TRPL("CartPole-v1")
    assert isinstance(trpl, TRPL)

@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128])
@pytest.mark.parametrize("gamma", [0.95, 0.99])
@pytest.mark.parametrize("max_kl", [0.01, 0.05])
def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, max_kl):
    trpl = TRPL(
        "CartPole-v1",
        learning_rate=learning_rate,
        n_steps=n_steps,
        batch_size=batch_size,
        gamma=gamma,
        max_kl=max_kl
    )
    assert trpl.learning_rate == learning_rate
    assert trpl.n_steps == n_steps
    assert trpl.batch_size == batch_size
    assert trpl.gamma == gamma
    assert trpl.max_kl == max_kl

def test_trpl_predict(simple_env):
    trpl = TRPL("CartPole-v1")
    obs, _ = simple_env.reset()
    action, _ = trpl.predict(obs)
    assert isinstance(action, np.ndarray)
    assert action.shape == simple_env.action_space.shape

def test_trpl_learn():
    trpl = TRPL("CartPole-v1", n_steps=64, batch_size=32)
    env = gym.make("CartPole-v1")
    obs, _ = env.reset()
    for _ in range(64):
        action, _ = trpl.predict(obs)
        next_obs, reward, done, truncated, _ = env.step(action)
        trpl.store_transition(obs, action, reward, done, next_obs)
        obs = next_obs
        if done or truncated:
            obs, _ = env.reset()

    loss = trpl.learn()
    assert isinstance(loss, dict)
    assert "policy_loss" in loss
    assert "value_loss" in loss

def test_trpl_training(simple_env):
    trpl = TRPL("CartPole-v1", total_timesteps=10000)

    initial_performance = evaluate_policy(trpl, simple_env)
    trpl.train()
    final_performance = evaluate_policy(trpl, simple_env)

    assert final_performance > initial_performance, "TRPL should improve performance after training"

def evaluate_policy(policy, env, n_eval_episodes=10):
    total_reward = 0
    for _ in range(n_eval_episodes):
        obs, _ = env.reset()
        done = False
        while not done:
            action, _ = policy.predict(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            done = terminated or truncated
    return total_reward / n_eval_episodes
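
The evaluate_policy helper at the end of test_trpl.py is duplicated verbatim in test_vlearn.py below. A minimal sketch of pulling it into a shared module, assuming a hypothetical test/helpers.py (the file name and import path are illustrative, not part of this commit):

# test/helpers.py -- hypothetical shared helper, illustrative only.
def evaluate_policy(policy, env, n_eval_episodes=10):
    """Average undiscounted return over n_eval_episodes episodes."""
    total_reward = 0
    for _ in range(n_eval_episodes):
        obs, _ = env.reset()
        done = False
        while not done:
            action, _ = policy.predict(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            done = terminated or truncated
    return total_reward / n_eval_episodes

Each test module could then import it instead of redefining it; the exact import statement depends on how the repo configures pytest's rootdir, which is an assumption here.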
81  test/test_vlearn.py  Normal file
@@ -0,0 +1,81 @@
import pytest
import torch
import numpy as np
from fancy_rl import VLEARN
import gymnasium as gym

@pytest.fixture
def simple_env():
    return gym.make('CartPole-v1')

def test_vlearn_instantiation():
    vlearn = VLEARN("CartPole-v1")
    assert isinstance(vlearn, VLEARN)

@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128])
@pytest.mark.parametrize("gamma", [0.95, 0.99])
@pytest.mark.parametrize("mean_bound", [0.05, 0.1])
@pytest.mark.parametrize("cov_bound", [0.0005, 0.001])
def test_vlearn_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, mean_bound, cov_bound):
    vlearn = VLEARN(
        "CartPole-v1",
        learning_rate=learning_rate,
        n_steps=n_steps,
        batch_size=batch_size,
        gamma=gamma,
        mean_bound=mean_bound,
        cov_bound=cov_bound
    )
    assert vlearn.learning_rate == learning_rate
    assert vlearn.n_steps == n_steps
    assert vlearn.batch_size == batch_size
    assert vlearn.gamma == gamma
    assert vlearn.mean_bound == mean_bound
    assert vlearn.cov_bound == cov_bound

def test_vlearn_predict(simple_env):
    vlearn = VLEARN("CartPole-v1")
    obs, _ = simple_env.reset()
    action, _ = vlearn.predict(obs)
    assert isinstance(action, np.ndarray)
    assert action.shape == simple_env.action_space.shape

def test_vlearn_learn():
    vlearn = VLEARN("CartPole-v1", n_steps=64, batch_size=32)
    env = gym.make("CartPole-v1")
    obs, _ = env.reset()
    for _ in range(64):
        action, _ = vlearn.predict(obs)
        next_obs, reward, done, truncated, _ = env.step(action)
        vlearn.store_transition(obs, action, reward, done, next_obs)
        obs = next_obs
        if done or truncated:
            obs, _ = env.reset()

    loss = vlearn.learn()
    assert isinstance(loss, dict)
    assert "policy_loss" in loss
    assert "value_loss" in loss

def test_vlearn_training(simple_env):
    vlearn = VLEARN("CartPole-v1", total_timesteps=10000)

    initial_performance = evaluate_policy(vlearn, simple_env)
    vlearn.train()
    final_performance = evaluate_policy(vlearn, simple_env)

    assert final_performance > initial_performance, "VLearn should improve performance after training"

def evaluate_policy(policy, env, n_eval_episodes=10):
    total_reward = 0
    for _ in range(n_eval_episodes):
        obs, _ = env.reset()
        done = False
        while not done:
            action, _ = policy.predict(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            done = terminated or truncated
    return total_reward / n_eval_episodes
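
The stacked parametrize decorators multiply quickly: the PPO hyper-parameter test alone expands to 3·2·3·2·2·3 = 216 cases, and the TRPL and VLEARN variants add 72 and 144 more. A minimal sketch of narrowing a local run to the cheap tests using pytest's documented programmatic entry point; the script name and the test/ path are assumptions based on the file headers above:

# run_subset.py -- hypothetical convenience script, illustrative only.
# pytest.main() accepts the same arguments as the CLI; -k filters by test
# name, so only the instantiation and predict tests run here.
import sys
import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["test/", "-k", "instantiation or predict", "-q"]))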