From 1706bea571d78e06d5df247b530e6bffe2521651 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 13 Jul 2022 19:39:09 +0200 Subject: [PATCH] Testing SDC --- test.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/test.py b/test.py index ad3dbdd..991229b 100755 --- a/test.py +++ b/test.py @@ -11,6 +11,8 @@ from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy from metastable_baselines.trl_pg import TRL_PG +from metastable_baselines.trl_pg.policies import MlpPolicy +from metastable_baselines.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer import columbus #root_path = os.getcwd() @@ -20,8 +22,8 @@ root_path = '.' def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=True, saveModel=True, n_eval_episodes=0): env = gym.make(env_name) use_sde = False - ppo = PPO( - "MlpPolicy", + ppo = TRL_PG( + MlpPolicy, env, verbose=0, tensorboard_log=root_path+"/logs_tb/" + @@ -36,8 +38,9 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=Tr clip_range=0.2, ) trl_pg = TRL_PG( - "MlpPolicy", + MlpPolicy, env, + projection=FrobeniusProjectionLayer(), verbose=0, tensorboard_log=root_path+"/logs_tb/"+env_name + "/trl_pg"+(['', '_sde'][use_sde])+"/", @@ -54,9 +57,9 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=Tr print('TRL_PG:') testModel(trl_pg, timesteps, showRes, saveModel, n_eval_episodes) - # print('PPO:') - # testModel(ppo, timesteps, showRes, - # saveModel, n_eval_episodes) + print('PPO:') + testModel(ppo, timesteps, showRes, + saveModel, n_eval_episodes) def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16):