Added tests
This commit is contained in:
		
							parent
							
								
									39c21ab6b9
								
							
						
					
					
						commit
						8d4f57a59d
					
				
							
								
								
									
										9
									
								
								tests/test_ppo.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								tests/test_ppo.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,9 @@ | |||||||
|  | import gymnasium as gym | ||||||
|  | import pytest | ||||||
|  | 
 | ||||||
|  | from metastable_baselines2 import PPO | ||||||
|  | 
 | ||||||
|  | @pytest.mark.parametrize("env_id", ["LunarLanderContinuous-v2", "MountainCarContinuous-v0"]) | ||||||
|  | def test_trpl(env_id): | ||||||
|  |     model = PPO("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1) | ||||||
|  |     model.learn(total_timesteps=500) | ||||||
							
								
								
									
										20
									
								
								tests/test_trpl.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								tests/test_trpl.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,20 @@ | |||||||
|  | import gymnasium as gym | ||||||
|  | import pytest | ||||||
|  | 
 | ||||||
|  | from metastable_baselines2 import TRPL | ||||||
|  | 
 | ||||||
|  | PROJECTIONS = ["Frobenius", "Wasserstein"] #, "KL"] | ||||||
|  | 
 | ||||||
|  | @pytest.mark.parametrize("env_id", ["LunarLanderContinuous-v2", "MountainCarContinuous-v0"]) | ||||||
|  | @pytest.mark.parametrize("projection", PROJECTIONS) | ||||||
|  | def test_trpl(env_id, projection): | ||||||
|  |     model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), projection_class=projection, verbose=1) | ||||||
|  |     model.learn(total_timesteps=500) | ||||||
|  | 
 | ||||||
|  | @pytest.mark.parametrize("env_id", ["LunarLanderContinuous-v2"]) | ||||||
|  | @pytest.mark.parametrize("projection", PROJECTIONS) | ||||||
|  | @pytest.mark.parametrize("mean_bound", [0.03, 0.06]) | ||||||
|  | @pytest.mark.parametrize("cov_bound", [1.0e-3, 2.0e-3]) | ||||||
|  | def test_trpl_params(env_id, projection, mean_bound, cov_bound): | ||||||
|  |     model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), projection_class=projection, projection_kwargs={'mean_bound': mean_bound, 'cov_bound': cov_bound}, verbose=1) | ||||||
|  |     model.learn(total_timesteps=100) | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user