Allow manual early stopping of training (Ctrl+C)
This commit is contained in:
		
							parent
							
								
									e8d423f91f
								
							
						
					
					
						commit
						28561b9bb2
					
				
							
								
								
									
										70
									
								
								test.py
									
									
									
									
									
								
							
							
						
						
									
										70
									
								
								test.py
									
									
									
									
									
								
							| @ -17,75 +17,39 @@ import columbus | ||||
| root_path = '.' | ||||
| 
 | ||||
| 
 | ||||
| def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=500000, showRes=True, saveModel=True, n_eval_episodes=0): | ||||
| def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=200_000, showRes=True, saveModel=True, n_eval_episodes=0): | ||||
|     env = gym.make(env_name) | ||||
|     test_sde = False | ||||
|     use_sde = False | ||||
|     ppo = PPO( | ||||
|         "MlpPolicy", | ||||
|         env, | ||||
|         verbose=0, | ||||
|         tensorboard_log=root_path+"/logs_tb/"+env_name+"/ppo/", | ||||
|         tensorboard_log=root_path+"/logs_tb/" + | ||||
|         env_name+"/ppo"+(['', '_sde'][use_sde])+"/", | ||||
|         learning_rate=3e-4, | ||||
|         gamma=0.99, | ||||
|         gae_lambda=0.95, | ||||
|         normalize_advantage=True, | ||||
|         ent_coef=0.1,  # 0.1 | ||||
|         ent_coef=0.02,  # 0.1 | ||||
|         vf_coef=0.5, | ||||
|         use_sde=False,  # False | ||||
|         use_sde=use_sde,  # False | ||||
|         clip_range=0.2, | ||||
|     ) | ||||
|     trl_pg = TRL_PG( | ||||
|         "MlpPolicy", | ||||
|         env, | ||||
|         verbose=0, | ||||
|         tensorboard_log=root_path+"/logs_tb/"+env_name+"/trl_pg/", | ||||
|         tensorboard_log=root_path+"/logs_tb/"+env_name + | ||||
|         "/trl_pg"+(['', '_sde'][use_sde])+"/", | ||||
|         learning_rate=3e-4, | ||||
|         gamma=0.99, | ||||
|         gae_lambda=0.95, | ||||
|         normalize_advantage=True, | ||||
|         ent_coef=0.1,  # 0.1 | ||||
|         ent_coef=0.03,  # 0.1 | ||||
|         vf_coef=0.5, | ||||
|         use_sde=False,  # False | ||||
|         use_sde=use_sde, | ||||
|         clip_range=2,  # 0.2 | ||||
|     ) | ||||
|     if test_sde: | ||||
|         ppo_latent_sde = PPO( | ||||
|             "MlpPolicy", | ||||
|             env, | ||||
|             verbose=0, | ||||
|             tensorboard_log=root_path+"/logs_tb/"+env_name+"/ppo_latent_sde/", | ||||
|             learning_rate=3e-4, | ||||
|             gamma=0.99, | ||||
|             gae_lambda=0.95, | ||||
|             normalize_advantage=True, | ||||
|             ent_coef=0.15,  # 0.1 | ||||
|             vf_coef=0.5, | ||||
|             use_sde=True,  # False | ||||
|             sde_sample_freq=30*15,  # -1 | ||||
|         ) | ||||
|         trl_pg_latent_sde = TRL_PG( | ||||
|             "MlpPolicy", | ||||
|             env, | ||||
|             verbose=0, | ||||
|             tensorboard_log=root_path+"/logs_tb/"+env_name+"/trl_pg_latent_sde/", | ||||
|             learning_rate=3e-4, | ||||
|             gamma=0.99, | ||||
|             gae_lambda=0.95, | ||||
|             normalize_advantage=True, | ||||
|             ent_coef=0.15,  # 0.1 | ||||
|             vf_coef=0.5, | ||||
|             use_sde=True,  # False | ||||
|             sde_sample_freq=30*15,  # -1 | ||||
|         ) | ||||
|     # sac_latent_sde = SAC( | ||||
|     #    "MlpPolicy", | ||||
|     #    env, | ||||
|     #    verbose=0, | ||||
|     #    tensorboard_log=root_path+"/logs_tb/"+env_name+"/sac_latent_sde/", | ||||
|     #    use_sde=True, | ||||
|     #    sde_sample_freq=30*15, | ||||
|     #    ent_coef=0.0016, #0.0032 | ||||
|     #    gamma=0.99, # 0.95 | ||||
|     #    learning_rate=0.001 # 0.015 | ||||
|     # ) | ||||
| 
 | ||||
|     print('TRL_PG:') | ||||
|     testModel(trl_pg, timesteps, showRes, | ||||
| @ -97,13 +61,18 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=500000, showRes=True, | ||||
| 
 | ||||
| def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16): | ||||
|     env = model.get_env() | ||||
|     model.learn(timesteps) | ||||
|     try: | ||||
|         model.learn(timesteps) | ||||
|     except KeyboardInterrupt: | ||||
|         print('[!] Training Terminated') | ||||
|         pass | ||||
| 
 | ||||
|     if saveModel: | ||||
|         now = datetime.datetime.now().strftime('%d.%m.%Y-%H:%M') | ||||
|         loc = root_path+'/models/' + \ | ||||
|             model.tensorboard_log.replace( | ||||
|                 root_path+'/logs_tb/', '').replace('/', '_')+now+'.zip' | ||||
|         print(model.get_parameters()) | ||||
|         model.save(loc) | ||||
| 
 | ||||
|     if n_eval_episodes: | ||||
| @ -132,3 +101,6 @@ def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes= | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     main('LunarLanderContinuous-v2') | ||||
|     #main('ColumbusJustState-v0') | ||||
|     #main('ColumbusStateWithBarriers-v0') | ||||
|     #main('ColumbusEasierObstacles-v0') | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user