From 84b37108509381e41a1bd9a68e32d7395ac9156e Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 21 Jun 2022 15:15:38 +0200 Subject: [PATCH] Testing SACs ability to solve EasierObstacles-v0 --- test.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/test.py b/test.py index 17accf5..1a3ee7d 100644 --- a/test.py +++ b/test.py @@ -12,19 +12,30 @@ from sb3_trl.trl_pg import TRL_PG import columbus -def main(env_name='ColumbusEasyObstacles-v0'): +def main(env_name='ColumbusEasierObstacles-v0'): env = gym.make(env_name) ppo_latent_sde = PPO( "MlpPolicy", env, - verbose=1, + verbose=0, tensorboard_log="./logs_tb/"+env_name+"/ppo_latent_sde/", use_sde=True, sde_sample_freq=30*15, - ent_coef=0.0032, - vf_coef=0.0005, - gamma=0.95, - learning_rate=0.02 + ent_coef=0.0016/1.25, #0.0032 + vf_coef=0.00025/2, #0.0005 + gamma=0.99, # 0.95 + learning_rate=0.005/5 # 0.015 + ) + sac_latent_sde = SAC( + "MlpPolicy", + env, + verbose=0, + tensorboard_log="./logs_tb/"+env_name+"/sac_latent_sde/", + use_sde=True, + sde_sample_freq=30*15, + ent_coef=0.0016, #0.0032 + gamma=0.99, # 0.95 + learning_rate=0.001 # 0.015 ) #trl = TRL_PG( # "MlpPolicy", @@ -33,8 +44,10 @@ def main(env_name='ColumbusEasyObstacles-v0'): # tensorboard_log="./logs_tb/"+env_name+"/trl_pg/", #) - print('PPO_LATENT_SDE:') - testModel(ppo_latent_sde, 100000, showRes = True, saveModel=True, n_eval_episodes=0) + #print('PPO_LATENT_SDE:') + #testModel(ppo_latent_sde, 1000000, showRes = True, saveModel=True, n_eval_episodes=3) + print('SAC_LATENT_SDE:') + testModel(ppo_latent_sde, 250000, showRes = True, saveModel=True, n_eval_episodes=0) #print('TRL_PG:') #testModel(trl) @@ -45,7 +58,7 @@ def testModel(model, timesteps=150000, showRes=False, saveModel=False, n_eval_ep if saveModel: now = datetime.datetime.now().strftime('%d.%m.%Y-%H:%M') - model.save(model.tensorboard_log.replace('./logs_tb/','').replace('/','_')+now+'.zip') + model.save('models/'+model.tensorboard_log.replace('./logs_tb/','').replace('/','_')+now+'.zip') if n_eval_episodes: mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=n_eval_episodes, deterministic=False)