Testing the RayObserver
commit 477a3c48b1 (parent 605a81c81c)

1 changed file: test.py (57 lines changed)
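test.py switches the test setup from the ColumbusTest3.1 environment to the new ColumbusTestRay environment (now imported from columbus directly instead of subtrees.columbus), raises the episode limit to 30*60*5 steps, adds a second, gSDE-enabled PPO instance next to the plain one, gives each algorithm its own tensorboard_log subdirectory, and extends testModel with saveModel and n_eval_episodes parameters: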
@@ -7,54 +7,71 @@ from stable_baselines3 import SAC, PPO, A2C
 from stable_baselines3.common.evaluation import evaluate_policy
 
 from sb3_trl.trl_pg import TRL_PG
-from subtrees.columbus import env
+from columbus import env
 
 register(
-    id='ColumbusTest3.1-v0',
-    entry_point=env.ColumbusTest3_1,
-    max_episode_steps=1000,
+    id='ColumbusTestRay-v0',
+    entry_point=env.ColumbusTestRay,
+    max_episode_steps=30*60*5,
 )
 
 def main():
     #env = gym.make("LunarLander-v2")
-    env = gym.make("ColumbusTest3.1-v0")
+    env = gym.make("ColumbusTestRay-v0")
 
     ppo = PPO(
         "MlpPolicy",
         env,
-        verbose=0,
-        tensorboard_log="./logs_tb/test/",
+        verbose=1,
+        tensorboard_log="./logs_tb/test/ppo",
+        use_sde=False,
+        ent_coef=0.0001,
+        learning_rate=0.0004
+    )
+    ppo_sde = PPO(
+        "MlpPolicy",
+        env,
+        verbose=1,
+        tensorboard_log="./logs_tb/test/ppo_sde/",
+        use_sde=True,
+        sde_sample_freq=30*20,
+        ent_coef=0.000001,
+        learning_rate=0.0003
     )
     a2c = A2C(
         "MlpPolicy",
         env,
-        verbose=0,
-        tensorboard_log="./logs_tb/test/",
+        verbose=1,
+        tensorboard_log="./logs_tb/test/a2c/",
     )
     trl = TRL_PG(
         "MlpPolicy",
         env,
         verbose=0,
-        tensorboard_log="./logs_tb/test/",
+        tensorboard_log="./logs_tb/test/trl_pg/",
     )
 
-    print('PPO:')
-    testModel(ppo)
-    print('A2C:')
-    testModel(a2c)
-    print('TRL_PG:')
-    testModel(trl)
+    #print('PPO:')
+    #testModel(ppo, 500000, showRes = True, saveModel=True, n_eval_episodes=4)
+    print('PPO_SDE:')
+    testModel(ppo_sde, 100000, showRes = True, saveModel=True, n_eval_episodes=0)
+    #print('A2C:')
+    #testModel(a2c, showRes = True)
+    #print('TRL_PG:')
+    #testModel(trl)
 
 
-def testModel(model, timesteps=50000, showRes=False):
+def testModel(model, timesteps=100000, showRes=False, saveModel=False, n_eval_episodes=16):
     env = model.get_env()
     model.learn(timesteps)
 
-    mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=16, deterministic=False)
-
-    print('Reward: '+str(round(mean_reward,3))+'±'+str(round(std_reward,2)))
+    if n_eval_episodes:
+        mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=n_eval_episodes, deterministic=False)
+        print('Reward: '+str(round(mean_reward,3))+'±'+str(round(std_reward,2)))
 
     if showRes:
+        model.save("model")
+        input('<ready?>')
         obs = env.reset()
         # Evaluate the agent
         episode_reward = 0
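The new showRes branch saves the trained policy via model.save("model") before rendering, so a finished run can be replayed without retraining. Below is a minimal replay sketch under stated assumptions: it reuses the registration from the diff, relies on stable-baselines3's standard PPO.load/predict API and gym's classic reset/step interface; the replay helper itself is hypothetical and not part of this commit.

# Replay sketch (hypothetical helper; not part of this commit).
import gym
from gym.envs.registration import register

from stable_baselines3 import PPO
from columbus import env

register(
    id='ColumbusTestRay-v0',
    entry_point=env.ColumbusTestRay,
    max_episode_steps=30*60*5,  # same limit as in test.py (presumably 5 min at 30 fps)
)


def replay(path="model"):
    # "model" is the file written by model.save("model") in testModel()
    e = gym.make("ColumbusTestRay-v0")
    model = PPO.load(path)  # restores policy weights and hyperparameters
    obs = e.reset()
    done = False
    episode_reward = 0
    while not done:
        # deterministic=False matches the stochastic evaluation in testModel()
        action, _states = model.predict(obs, deterministic=False)
        obs, reward, done, info = e.step(action)
        episode_reward += reward
        e.render()
    print('Reward: ' + str(round(episode_reward, 3)))


if __name__ == '__main__':
    replay()

Note that the active ppo_sde run passes n_eval_episodes=0, so evaluate_policy is skipped for it; the reward printout only appears for configurations that keep a non-zero episode count.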