Support SAC in replays

This commit is contained in:
Dominik Moritz Roth 2022-07-19 10:08:34 +02:00
parent 0162a36824
commit 6384d411a9

View File

@ -10,6 +10,7 @@ from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
from metastable_baselines.ppo import PPO
from metastable_baselines.sac import SAC
import columbus
@ -26,7 +27,12 @@ def main(load_path, n_eval_episodes=0):
print(env_name, alg_name, alg_deriv, use_sde)
env = gym.make(env_name)
model = PPO.load(load_path, env=env)
if alg_name == 'ppo':
Model = PPO
elif alg_name == 'sac':
Model = SAC
model = Model.load(load_path, env=env)
show_chol = env_name.startswith('Columbus')