Testing SDC
This commit is contained in:
parent
3304fd49f6
commit
1706bea571
15
test.py
15
test.py
@ -11,6 +11,8 @@ from stable_baselines3.common.evaluation import evaluate_policy
|
|||||||
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
|
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
|
||||||
|
|
||||||
from metastable_baselines.trl_pg import TRL_PG
|
from metastable_baselines.trl_pg import TRL_PG
|
||||||
|
from metastable_baselines.trl_pg.policies import MlpPolicy
|
||||||
|
from metastable_baselines.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
|
||||||
import columbus
|
import columbus
|
||||||
|
|
||||||
#root_path = os.getcwd()
|
#root_path = os.getcwd()
|
||||||
@ -20,8 +22,8 @@ root_path = '.'
|
|||||||
def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
|
def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
|
||||||
env = gym.make(env_name)
|
env = gym.make(env_name)
|
||||||
use_sde = False
|
use_sde = False
|
||||||
ppo = PPO(
|
ppo = TRL_PG(
|
||||||
"MlpPolicy",
|
MlpPolicy,
|
||||||
env,
|
env,
|
||||||
verbose=0,
|
verbose=0,
|
||||||
tensorboard_log=root_path+"/logs_tb/" +
|
tensorboard_log=root_path+"/logs_tb/" +
|
||||||
@ -36,8 +38,9 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=Tr
|
|||||||
clip_range=0.2,
|
clip_range=0.2,
|
||||||
)
|
)
|
||||||
trl_pg = TRL_PG(
|
trl_pg = TRL_PG(
|
||||||
"MlpPolicy",
|
MlpPolicy,
|
||||||
env,
|
env,
|
||||||
|
projection=FrobeniusProjectionLayer(),
|
||||||
verbose=0,
|
verbose=0,
|
||||||
tensorboard_log=root_path+"/logs_tb/"+env_name +
|
tensorboard_log=root_path+"/logs_tb/"+env_name +
|
||||||
"/trl_pg"+(['', '_sde'][use_sde])+"/",
|
"/trl_pg"+(['', '_sde'][use_sde])+"/",
|
||||||
@ -54,9 +57,9 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=Tr
|
|||||||
print('TRL_PG:')
|
print('TRL_PG:')
|
||||||
testModel(trl_pg, timesteps, showRes,
|
testModel(trl_pg, timesteps, showRes,
|
||||||
saveModel, n_eval_episodes)
|
saveModel, n_eval_episodes)
|
||||||
# print('PPO:')
|
print('PPO:')
|
||||||
# testModel(ppo, timesteps, showRes,
|
testModel(ppo, timesteps, showRes,
|
||||||
# saveModel, n_eval_episodes)
|
saveModel, n_eval_episodes)
|
||||||
|
|
||||||
|
|
||||||
def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16):
|
def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16):
|
||||||
|
Loading…
Reference in New Issue
Block a user