Support SAC in replays
This commit is contained in:
parent
0162a36824
commit
6384d411a9
@ -10,6 +10,7 @@ from stable_baselines3.common.evaluation import evaluate_policy
|
|||||||
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
|
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
|
||||||
|
|
||||||
from metastable_baselines.ppo import PPO
|
from metastable_baselines.ppo import PPO
|
||||||
|
from metastable_baselines.sac import SAC
|
||||||
import columbus
|
import columbus
|
||||||
|
|
||||||
|
|
||||||
@ -26,7 +27,12 @@ def main(load_path, n_eval_episodes=0):
|
|||||||
print(env_name, alg_name, alg_deriv, use_sde)
|
print(env_name, alg_name, alg_deriv, use_sde)
|
||||||
env = gym.make(env_name)
|
env = gym.make(env_name)
|
||||||
|
|
||||||
model = PPO.load(load_path, env=env)
|
if alg_name == 'ppo':
|
||||||
|
Model = PPO
|
||||||
|
elif alg_name == 'sac':
|
||||||
|
Model = SAC
|
||||||
|
|
||||||
|
model = Model.load(load_path, env=env)
|
||||||
|
|
||||||
show_chol = env_name.startswith('Columbus')
|
show_chol = env_name.startswith('Columbus')
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user