Testing SAC's ability to solve EasierObstacles-v0
This commit is contained in:
		
							parent
							
								
									13d335f856
								
							
						
					
					
						commit
						84b3710850
					
				
							
								
								
									
										31
									
								
								test.py
									
									
									
									
									
								
							
							
						
						
									
										31
									
								
								test.py
									
									
									
									
									
								
							| @ -12,19 +12,30 @@ from sb3_trl.trl_pg import TRL_PG | ||||
| import columbus | ||||
| 
 | ||||
| 
 | ||||
| def main(env_name='ColumbusEasyObstacles-v0'): | ||||
| def main(env_name='ColumbusEasierObstacles-v0'): | ||||
|     env = gym.make(env_name) | ||||
|     ppo_latent_sde = PPO( | ||||
|         "MlpPolicy", | ||||
|         env, | ||||
|         verbose=1, | ||||
|         verbose=0, | ||||
|         tensorboard_log="./logs_tb/"+env_name+"/ppo_latent_sde/", | ||||
|         use_sde=True, | ||||
|         sde_sample_freq=30*15, | ||||
|         ent_coef=0.0032, | ||||
|         vf_coef=0.0005, | ||||
|         gamma=0.95, | ||||
|         learning_rate=0.02 | ||||
|         ent_coef=0.0016/1.25, #0.0032 | ||||
|         vf_coef=0.00025/2, #0.0005 | ||||
|         gamma=0.99, # 0.95 | ||||
|         learning_rate=0.005/5 # 0.015 | ||||
|     ) | ||||
|     sac_latent_sde = SAC( | ||||
|         "MlpPolicy", | ||||
|         env, | ||||
|         verbose=0, | ||||
|         tensorboard_log="./logs_tb/"+env_name+"/sac_latent_sde/", | ||||
|         use_sde=True, | ||||
|         sde_sample_freq=30*15, | ||||
|         ent_coef=0.0016, #0.0032 | ||||
|         gamma=0.99, # 0.95 | ||||
|         learning_rate=0.001 # 0.015 | ||||
|     ) | ||||
|     #trl = TRL_PG( | ||||
|     #    "MlpPolicy", | ||||
| @ -33,8 +44,10 @@ def main(env_name='ColumbusEasyObstacles-v0'): | ||||
|     #    tensorboard_log="./logs_tb/"+env_name+"/trl_pg/", | ||||
|     #) | ||||
| 
 | ||||
|     print('PPO_LATENT_SDE:') | ||||
|     testModel(ppo_latent_sde, 100000, showRes = True, saveModel=True, n_eval_episodes=0) | ||||
|     #print('PPO_LATENT_SDE:') | ||||
|     #testModel(ppo_latent_sde, 1000000, showRes = True, saveModel=True, n_eval_episodes=3) | ||||
|     print('SAC_LATENT_SDE:') | ||||
|     testModel(ppo_latent_sde, 250000, showRes = True, saveModel=True, n_eval_episodes=0) | ||||
|     #print('TRL_PG:') | ||||
|     #testModel(trl) | ||||
| 
 | ||||
| @ -45,7 +58,7 @@ def testModel(model, timesteps=150000, showRes=False, saveModel=False, n_eval_ep | ||||
| 
 | ||||
|     if saveModel: | ||||
|         now = datetime.datetime.now().strftime('%d.%m.%Y-%H:%M') | ||||
|         model.save(model.tensorboard_log.replace('./logs_tb/','').replace('/','_')+now+'.zip') | ||||
|         model.save('models/'+model.tensorboard_log.replace('./logs_tb/','').replace('/','_')+now+'.zip') | ||||
| 
 | ||||
|     if n_eval_episodes: | ||||
|         mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=n_eval_episodes, deterministic=False) | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user