diff --git a/test.py b/test.py index dbe7a53..f01ffec 100644 --- a/test.py +++ b/test.py @@ -1,4 +1,5 @@ import gym +from gym.envs.registration import register import numpy as np import time @@ -6,9 +7,17 @@ from stable_baselines3 import SAC, PPO, A2C from stable_baselines3.common.evaluation import evaluate_policy from sb3_trl.trl_pg import TRL_PG +from subtrees.columbus import env + +register( + id='Columbus_Test3.1-v0', + entry_point=env.ColumbusEnv, + max_episode_steps=1000, +) def main(): - env = gym.make("LunarLander-v2") + #env = gym.make("LunarLander-v2") + env = gym.make("Columbus_test3.1-v0") ppo = PPO( "MlpPolicy", env,