diff --git a/replay.py b/replay.py index c135b9f..854b159 100755 --- a/replay.py +++ b/replay.py @@ -10,6 +10,7 @@ from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy from metastable_baselines.ppo import PPO +from metastable_baselines.sac import SAC import columbus @@ -26,7 +27,12 @@ def main(load_path, n_eval_episodes=0): print(env_name, alg_name, alg_deriv, use_sde) env = gym.make(env_name) - model = PPO.load(load_path, env=env) + if alg_name == 'ppo': + Model = PPO + elif alg_name == 'sac': + Model = SAC + + model = Model.load(load_path, env=env) show_chol = env_name.startswith('Columbus')