Expand Tests
parent e927afcc30 · commit abc8dcbda1
@@ -3,12 +3,15 @@ import numpy as np
 from fancy_rl import PPO
 import gymnasium as gym
 
-@pytest.fixture
 def simple_env():
-    return gym.make('CartPole-v1')
+    return gym.make('LunarLander-v2', continuous=True)
 
 def test_ppo_instantiation():
-    ppo = PPO("CartPole-v1")
+    ppo = PPO(simple_env)
+    assert isinstance(ppo, PPO)
+
+def test_ppo_instantiation_from_str():
+    ppo = PPO('CartPole-v1')
     assert isinstance(ppo, PPO)
 
 @pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@@ -19,7 +22,7 @@ def test_ppo_instantiation():
 @pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
 def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
     ppo = PPO(
-        "CartPole-v1",
+        simple_env,
         learning_rate=learning_rate,
         n_steps=n_steps,
         batch_size=batch_size,
@@ -34,26 +37,42 @@ def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_siz
     assert ppo.gamma == gamma
     assert ppo.clip_range == clip_range
 
-def test_ppo_predict(simple_env):
-    ppo = PPO("CartPole-v1")
-    obs, _ = simple_env.reset()
+def test_ppo_predict():
+    ppo = PPO(simple_env)
+    env = ppo.make_env()
+    obs, _ = env.reset()
     action, _ = ppo.predict(obs)
     assert isinstance(action, np.ndarray)
-    assert action.shape == simple_env.action_space.shape
+    assert action.shape == env.action_space.shape
 
 def test_ppo_learn():
-    ppo = PPO("CartPole-v1", n_steps=64, batch_size=32)
-    env = gym.make("CartPole-v1")
+    ppo = PPO(simple_env, n_steps=64, batch_size=32)
+    env = ppo.make_env()
     obs, _ = env.reset()
     for _ in range(64):
         action, _ = ppo.predict(obs)
-        next_obs, reward, done, truncated, _ = env.step(action)
-        ppo.store_transition(obs, action, reward, done, next_obs)
-        obs = next_obs
+        obs, reward, done, truncated, _ = env.step(action)
         if done or truncated:
             obs, _ = env.reset()
 
-    loss = ppo.learn()
-    assert isinstance(loss, dict)
-    assert "policy_loss" in loss
-    assert "value_loss" in loss
+def test_ppo_training():
+    ppo = PPO(simple_env, total_timesteps=10000)
+    env = ppo.make_env()
+
+    initial_performance = evaluate_policy(ppo, env)
+    ppo.train()
+    final_performance = evaluate_policy(ppo, env)
+
+    assert final_performance > initial_performance, "PPO should improve performance after training"
+
+def evaluate_policy(policy, env, n_eval_episodes=10):
+    total_reward = 0
+    for _ in range(n_eval_episodes):
+        obs, _ = env.reset()
+        done = False
+        while not done:
+            action, _ = policy.predict(obs)
+            obs, reward, terminated, truncated, _ = env.step(action)
+            total_reward += reward
+            done = terminated or truncated
+    return total_reward / n_eval_episodes
@@ -3,12 +3,15 @@ import numpy as np
 from fancy_rl import TRPL
 import gymnasium as gym
 
-@pytest.fixture
 def simple_env():
-    return gym.make('CartPole-v1')
+    return gym.make('LunarLander-v2', continuous=True)
 
 def test_trpl_instantiation():
-    trpl = TRPL("CartPole-v1")
+    trpl = TRPL(simple_env)
+    assert isinstance(trpl, TRPL)
+
+def test_trpl_instantiation_from_str():
+    trpl = TRPL('MountainCarContinuous-v0')
     assert isinstance(trpl, TRPL)
 
 @pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@@ -19,7 +22,7 @@ def test_trpl_instantiation():
 @pytest.mark.parametrize("trust_region_bound_cov", [0.0005, 0.001])
 def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, trust_region_bound_mean, trust_region_bound_cov):
     trpl = TRPL(
-        "CartPole-v1",
+        simple_env,
         learning_rate=learning_rate,
         n_steps=n_steps,
         batch_size=batch_size,
@@ -34,16 +37,17 @@ def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_si
     assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean
     assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov
 
-def test_trpl_predict(simple_env):
-    trpl = TRPL("CartPole-v1")
-    obs, _ = simple_env.reset()
+def test_trpl_predict():
+    trpl = TRPL(simple_env)
+    env = trpl.make_env()
+    obs, _ = env.reset()
     action, _ = trpl.predict(obs)
     assert isinstance(action, np.ndarray)
-    assert action.shape == simple_env.action_space.shape
+    assert action.shape == env.action_space.shape
 
 def test_trpl_learn():
-    trpl = TRPL("CartPole-v1", n_steps=64, batch_size=32)
-    env = gym.make("CartPole-v1")
+    trpl = TRPL(simple_env, n_steps=64, batch_size=32)
+    env = trpl.make_env()
     obs, _ = env.reset()
     for _ in range(64):
         action, _ = trpl.predict(obs)
@@ -58,12 +62,13 @@ def test_trpl_learn():
     assert "policy_loss" in loss
     assert "value_loss" in loss
 
-def test_trpl_training(simple_env):
-    trpl = TRPL("CartPole-v1", total_timesteps=10000)
+def test_trpl_training():
+    trpl = TRPL(simple_env, total_timesteps=10000)
+    env = trpl.make_env()
 
-    initial_performance = evaluate_policy(trpl, simple_env)
+    initial_performance = evaluate_policy(trpl, env)
     trpl.train()
-    final_performance = evaluate_policy(trpl, simple_env)
+    final_performance = evaluate_policy(trpl, env)
 
     assert final_performance > initial_performance, "TRPL should improve performance after training"
 
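Note: both modules now check training by comparing evaluated returns before and after train() rather than inspecting a loss dict alone. evaluate_policy is only defined in the PPO hunks; the TRPL hunks call it without showing a definition, so presumably the TRPL module defines or imports an equivalent helper. A sketch of the shared pattern under that assumption:

import gymnasium as gym
from fancy_rl import TRPL

def evaluate_policy(policy, env, n_eval_episodes=10):
    # average undiscounted return over a few episodes, as in the PPO hunks
    total_reward = 0
    for _ in range(n_eval_episodes):
        obs, _ = env.reset()
        done = False
        while not done:
            action, _ = policy.predict(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            done = terminated or truncated
    return total_reward / n_eval_episodes

# before/after training check, as in test_ppo_training and test_trpl_training
trpl = TRPL(lambda: gym.make('LunarLander-v2', continuous=True), total_timesteps=10000)
env = trpl.make_env()
before = evaluate_policy(trpl, env)
trpl.train()
after = evaluate_policy(trpl, env)
assert after > before, "training should improve the evaluated return"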