Testing SDC
This commit is contained in:
parent
3304fd49f6
commit
1706bea571
15
test.py
15
test.py
@ -11,6 +11,8 @@ from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
|
||||
|
||||
from metastable_baselines.trl_pg import TRL_PG
|
||||
from metastable_baselines.trl_pg.policies import MlpPolicy
|
||||
from metastable_baselines.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
|
||||
import columbus
|
||||
|
||||
#root_path = os.getcwd()
|
||||
@ -20,8 +22,8 @@ root_path = '.'
|
||||
def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
|
||||
env = gym.make(env_name)
|
||||
use_sde = False
|
||||
ppo = PPO(
|
||||
"MlpPolicy",
|
||||
ppo = TRL_PG(
|
||||
MlpPolicy,
|
||||
env,
|
||||
verbose=0,
|
||||
tensorboard_log=root_path+"/logs_tb/" +
|
||||
@ -36,8 +38,9 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=Tr
|
||||
clip_range=0.2,
|
||||
)
|
||||
trl_pg = TRL_PG(
|
||||
"MlpPolicy",
|
||||
MlpPolicy,
|
||||
env,
|
||||
projection=FrobeniusProjectionLayer(),
|
||||
verbose=0,
|
||||
tensorboard_log=root_path+"/logs_tb/"+env_name +
|
||||
"/trl_pg"+(['', '_sde'][use_sde])+"/",
|
||||
@ -54,9 +57,9 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=Tr
|
||||
print('TRL_PG:')
|
||||
testModel(trl_pg, timesteps, showRes,
|
||||
saveModel, n_eval_episodes)
|
||||
# print('PPO:')
|
||||
# testModel(ppo, timesteps, showRes,
|
||||
# saveModel, n_eval_episodes)
|
||||
print('PPO:')
|
||||
testModel(ppo, timesteps, showRes,
|
||||
saveModel, n_eval_episodes)
|
||||
|
||||
|
||||
def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16):
|
||||
|
Loading…
Reference in New Issue
Block a user