diff --git a/test.py b/test.py index d458c41..4e308ea 100644 --- a/test.py +++ b/test.py @@ -28,11 +28,22 @@ def main(): ent_coef=0.0001, learning_rate=0.0004 ) - ppo_sde = PPO( + ppo_base_sde = PPO( "MlpPolicy", env, verbose=1, - tensorboard_log="./logs_tb/test/ppo_sde/", + tensorboard_log="./logs_tb/test/ppo_base_sde/", + use_sde=True, + sde_sample_freq=30*20, + sde_net_arch=[], + ent_coef=0.000001, + learning_rate=0.0003 + ) + ppo_latent_sde = PPO( + "MlpPolicy", + env, + verbose=1, + tensorboard_log="./logs_tb/test/ppo_latent_sde/", use_sde=True, sde_sample_freq=30*20, ent_coef=0.000001, @@ -53,8 +64,8 @@ def main(): #print('PPO:') #testModel(ppo, 500000, showRes = True, saveModel=True, n_eval_episodes=4) - print('PPO_SDE:') - testModel(ppo_sde, 100000, showRes = True, saveModel=True, n_eval_episodes=0) + print('PPO_BASE_SDE:') + testModel(ppo_base_sde, 200000, showRes = True, saveModel=True, n_eval_episodes=0) #print('A2C:') #testModel(a2c, showRes = True) #print('TRL_PG:')