Testing SDC

This commit is contained in:
Dominik Moritz Roth 2022-07-13 19:39:09 +02:00
parent 3304fd49f6
commit 1706bea571

15
test.py
View File

@ -11,6 +11,8 @@ from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
from metastable_baselines.trl_pg import TRL_PG from metastable_baselines.trl_pg import TRL_PG
from metastable_baselines.trl_pg.policies import MlpPolicy
from metastable_baselines.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
import columbus import columbus
#root_path = os.getcwd() #root_path = os.getcwd()
@ -20,8 +22,8 @@ root_path = '.'
def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=True, saveModel=True, n_eval_episodes=0): def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
env = gym.make(env_name) env = gym.make(env_name)
use_sde = False use_sde = False
ppo = PPO( ppo = TRL_PG(
"MlpPolicy", MlpPolicy,
env, env,
verbose=0, verbose=0,
tensorboard_log=root_path+"/logs_tb/" + tensorboard_log=root_path+"/logs_tb/" +
@ -36,8 +38,9 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=Tr
clip_range=0.2, clip_range=0.2,
) )
trl_pg = TRL_PG( trl_pg = TRL_PG(
"MlpPolicy", MlpPolicy,
env, env,
projection=FrobeniusProjectionLayer(),
verbose=0, verbose=0,
tensorboard_log=root_path+"/logs_tb/"+env_name + tensorboard_log=root_path+"/logs_tb/"+env_name +
"/trl_pg"+(['', '_sde'][use_sde])+"/", "/trl_pg"+(['', '_sde'][use_sde])+"/",
@ -54,9 +57,9 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=Tr
print('TRL_PG:') print('TRL_PG:')
testModel(trl_pg, timesteps, showRes, testModel(trl_pg, timesteps, showRes,
saveModel, n_eval_episodes) saveModel, n_eval_episodes)
# print('PPO:') print('PPO:')
# testModel(ppo, timesteps, showRes, testModel(ppo, timesteps, showRes,
# saveModel, n_eval_episodes) saveModel, n_eval_episodes)
def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16): def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16):