Allow manual early stopping of training (Ctrl+C)
This commit is contained in:
		
							parent
							
								
									e8d423f91f
								
							
						
					
					
						commit
						28561b9bb2
					
				
							
								
								
									
										70
									
								
								test.py
									
									
									
									
									
								
							
							
						
						
									
										70
									
								
								test.py
									
									
									
									
									
								
							| @ -17,75 +17,39 @@ import columbus | |||||||
| root_path = '.' | root_path = '.' | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=500000, showRes=True, saveModel=True, n_eval_episodes=0): | def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=200_000, showRes=True, saveModel=True, n_eval_episodes=0): | ||||||
|     env = gym.make(env_name) |     env = gym.make(env_name) | ||||||
|     test_sde = False |     use_sde = False | ||||||
|     ppo = PPO( |     ppo = PPO( | ||||||
|         "MlpPolicy", |         "MlpPolicy", | ||||||
|         env, |         env, | ||||||
|         verbose=0, |         verbose=0, | ||||||
|         tensorboard_log=root_path+"/logs_tb/"+env_name+"/ppo/", |         tensorboard_log=root_path+"/logs_tb/" + | ||||||
|  |         env_name+"/ppo"+(['', '_sde'][use_sde])+"/", | ||||||
|         learning_rate=3e-4, |         learning_rate=3e-4, | ||||||
|         gamma=0.99, |         gamma=0.99, | ||||||
|         gae_lambda=0.95, |         gae_lambda=0.95, | ||||||
|         normalize_advantage=True, |         normalize_advantage=True, | ||||||
|         ent_coef=0.1,  # 0.1 |         ent_coef=0.02,  # 0.1 | ||||||
|         vf_coef=0.5, |         vf_coef=0.5, | ||||||
|         use_sde=False,  # False |         use_sde=use_sde,  # False | ||||||
|  |         clip_range=0.2, | ||||||
|     ) |     ) | ||||||
|     trl_pg = TRL_PG( |     trl_pg = TRL_PG( | ||||||
|         "MlpPolicy", |         "MlpPolicy", | ||||||
|         env, |         env, | ||||||
|         verbose=0, |         verbose=0, | ||||||
|         tensorboard_log=root_path+"/logs_tb/"+env_name+"/trl_pg/", |         tensorboard_log=root_path+"/logs_tb/"+env_name + | ||||||
|  |         "/trl_pg"+(['', '_sde'][use_sde])+"/", | ||||||
|         learning_rate=3e-4, |         learning_rate=3e-4, | ||||||
|         gamma=0.99, |         gamma=0.99, | ||||||
|         gae_lambda=0.95, |         gae_lambda=0.95, | ||||||
|         normalize_advantage=True, |         normalize_advantage=True, | ||||||
|         ent_coef=0.1,  # 0.1 |         ent_coef=0.03,  # 0.1 | ||||||
|         vf_coef=0.5, |         vf_coef=0.5, | ||||||
|         use_sde=False,  # False |         use_sde=use_sde, | ||||||
|  |         clip_range=2,  # 0.2 | ||||||
|     ) |     ) | ||||||
|     if test_sde: |  | ||||||
|         ppo_latent_sde = PPO( |  | ||||||
|             "MlpPolicy", |  | ||||||
|             env, |  | ||||||
|             verbose=0, |  | ||||||
|             tensorboard_log=root_path+"/logs_tb/"+env_name+"/ppo_latent_sde/", |  | ||||||
|             learning_rate=3e-4, |  | ||||||
|             gamma=0.99, |  | ||||||
|             gae_lambda=0.95, |  | ||||||
|             normalize_advantage=True, |  | ||||||
|             ent_coef=0.15,  # 0.1 |  | ||||||
|             vf_coef=0.5, |  | ||||||
|             use_sde=True,  # False |  | ||||||
|             sde_sample_freq=30*15,  # -1 |  | ||||||
|         ) |  | ||||||
|         trl_pg_latent_sde = TRL_PG( |  | ||||||
|             "MlpPolicy", |  | ||||||
|             env, |  | ||||||
|             verbose=0, |  | ||||||
|             tensorboard_log=root_path+"/logs_tb/"+env_name+"/trl_pg_latent_sde/", |  | ||||||
|             learning_rate=3e-4, |  | ||||||
|             gamma=0.99, |  | ||||||
|             gae_lambda=0.95, |  | ||||||
|             normalize_advantage=True, |  | ||||||
|             ent_coef=0.15,  # 0.1 |  | ||||||
|             vf_coef=0.5, |  | ||||||
|             use_sde=True,  # False |  | ||||||
|             sde_sample_freq=30*15,  # -1 |  | ||||||
|         ) |  | ||||||
|     # sac_latent_sde = SAC( |  | ||||||
|     #    "MlpPolicy", |  | ||||||
|     #    env, |  | ||||||
|     #    verbose=0, |  | ||||||
|     #    tensorboard_log=root_path+"/logs_tb/"+env_name+"/sac_latent_sde/", |  | ||||||
|     #    use_sde=True, |  | ||||||
|     #    sde_sample_freq=30*15, |  | ||||||
|     #    ent_coef=0.0016, #0.0032 |  | ||||||
|     #    gamma=0.99, # 0.95 |  | ||||||
|     #    learning_rate=0.001 # 0.015 |  | ||||||
|     # ) |  | ||||||
| 
 | 
 | ||||||
|     print('TRL_PG:') |     print('TRL_PG:') | ||||||
|     testModel(trl_pg, timesteps, showRes, |     testModel(trl_pg, timesteps, showRes, | ||||||
| @ -97,13 +61,18 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=500000, showRes=True, | |||||||
| 
 | 
 | ||||||
| def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16): | def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16): | ||||||
|     env = model.get_env() |     env = model.get_env() | ||||||
|     model.learn(timesteps) |     try: | ||||||
|  |         model.learn(timesteps) | ||||||
|  |     except KeyboardInterrupt: | ||||||
|  |         print('[!] Training Terminated') | ||||||
|  |         pass | ||||||
| 
 | 
 | ||||||
|     if saveModel: |     if saveModel: | ||||||
|         now = datetime.datetime.now().strftime('%d.%m.%Y-%H:%M') |         now = datetime.datetime.now().strftime('%d.%m.%Y-%H:%M') | ||||||
|         loc = root_path+'/models/' + \ |         loc = root_path+'/models/' + \ | ||||||
|             model.tensorboard_log.replace( |             model.tensorboard_log.replace( | ||||||
|                 root_path+'/logs_tb/', '').replace('/', '_')+now+'.zip' |                 root_path+'/logs_tb/', '').replace('/', '_')+now+'.zip' | ||||||
|  |         print(model.get_parameters()) | ||||||
|         model.save(loc) |         model.save(loc) | ||||||
| 
 | 
 | ||||||
|     if n_eval_episodes: |     if n_eval_episodes: | ||||||
| @ -132,3 +101,6 @@ def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes= | |||||||
| 
 | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     main('LunarLanderContinuous-v2') |     main('LunarLanderContinuous-v2') | ||||||
|  |     #main('ColumbusJustState-v0') | ||||||
|  |     #main('ColumbusStateWithBarriers-v0') | ||||||
|  |     #main('ColumbusEasierObstacles-v0') | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user