Testing SAC's ability to solve EasierObstacles-v0
This commit is contained in:
		
							parent
							
								
									13d335f856
								
							
						
					
					
						commit
						84b3710850
					
				
							
								
								
									
										31
									
								
								test.py
									
									
									
									
									
								
							
							
						
						
									
										31
									
								
								test.py
									
									
									
									
									
								
							| @ -12,19 +12,30 @@ from sb3_trl.trl_pg import TRL_PG | |||||||
| import columbus | import columbus | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def main(env_name='ColumbusEasyObstacles-v0'): | def main(env_name='ColumbusEasierObstacles-v0'): | ||||||
|     env = gym.make(env_name) |     env = gym.make(env_name) | ||||||
|     ppo_latent_sde = PPO( |     ppo_latent_sde = PPO( | ||||||
|         "MlpPolicy", |         "MlpPolicy", | ||||||
|         env, |         env, | ||||||
|         verbose=1, |         verbose=0, | ||||||
|         tensorboard_log="./logs_tb/"+env_name+"/ppo_latent_sde/", |         tensorboard_log="./logs_tb/"+env_name+"/ppo_latent_sde/", | ||||||
|         use_sde=True, |         use_sde=True, | ||||||
|         sde_sample_freq=30*15, |         sde_sample_freq=30*15, | ||||||
|         ent_coef=0.0032, |         ent_coef=0.0016/1.25, #0.0032 | ||||||
|         vf_coef=0.0005, |         vf_coef=0.00025/2, #0.0005 | ||||||
|         gamma=0.95, |         gamma=0.99, # 0.95 | ||||||
|         learning_rate=0.02 |         learning_rate=0.005/5 # 0.015 | ||||||
|  |     ) | ||||||
|  |     sac_latent_sde = SAC( | ||||||
|  |         "MlpPolicy", | ||||||
|  |         env, | ||||||
|  |         verbose=0, | ||||||
|  |         tensorboard_log="./logs_tb/"+env_name+"/sac_latent_sde/", | ||||||
|  |         use_sde=True, | ||||||
|  |         sde_sample_freq=30*15, | ||||||
|  |         ent_coef=0.0016, #0.0032 | ||||||
|  |         gamma=0.99, # 0.95 | ||||||
|  |         learning_rate=0.001 # 0.015 | ||||||
|     ) |     ) | ||||||
|     #trl = TRL_PG( |     #trl = TRL_PG( | ||||||
|     #    "MlpPolicy", |     #    "MlpPolicy", | ||||||
| @ -33,8 +44,10 @@ def main(env_name='ColumbusEasyObstacles-v0'): | |||||||
|     #    tensorboard_log="./logs_tb/"+env_name+"/trl_pg/", |     #    tensorboard_log="./logs_tb/"+env_name+"/trl_pg/", | ||||||
|     #) |     #) | ||||||
| 
 | 
 | ||||||
|     print('PPO_LATENT_SDE:') |     #print('PPO_LATENT_SDE:') | ||||||
|     testModel(ppo_latent_sde, 100000, showRes = True, saveModel=True, n_eval_episodes=0) |     #testModel(ppo_latent_sde, 1000000, showRes = True, saveModel=True, n_eval_episodes=3) | ||||||
|  |     print('SAC_LATENT_SDE:') | ||||||
|  |     testModel(ppo_latent_sde, 250000, showRes = True, saveModel=True, n_eval_episodes=0) | ||||||
|     #print('TRL_PG:') |     #print('TRL_PG:') | ||||||
|     #testModel(trl) |     #testModel(trl) | ||||||
| 
 | 
 | ||||||
| @ -45,7 +58,7 @@ def testModel(model, timesteps=150000, showRes=False, saveModel=False, n_eval_ep | |||||||
| 
 | 
 | ||||||
|     if saveModel: |     if saveModel: | ||||||
|         now = datetime.datetime.now().strftime('%d.%m.%Y-%H:%M') |         now = datetime.datetime.now().strftime('%d.%m.%Y-%H:%M') | ||||||
|         model.save(model.tensorboard_log.replace('./logs_tb/','').replace('/','_')+now+'.zip') |         model.save('models/'+model.tensorboard_log.replace('./logs_tb/','').replace('/','_')+now+'.zip') | ||||||
| 
 | 
 | ||||||
|     if n_eval_episodes: |     if n_eval_episodes: | ||||||
|         mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=n_eval_episodes, deterministic=False) |         mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=n_eval_episodes, deterministic=False) | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user