Fixed replay

This commit is contained in:
Dominik Moritz Roth 2022-06-29 17:02:40 +02:00
parent 28561b9bb2
commit 30c9e93967

View File

@ -27,23 +27,13 @@ def main(load_path, n_eval_episodes=0):
print(env_name, alg_name, alg_deriv, use_sde)
env = gym.make(env_name)
if alg_name == 'ppo':
model = PPO(
"MlpPolicy",
env,
use_sde=use_sde,
)
Model = PPO
elif alg_name == 'trl' and alg_deriv == 'pg':
model = TRL_PG(
"MlpPolicy",
env,
use_sde=use_sde,
)
Model = TRL_PG
else:
raise Exception('Algorithm not implemented for replay')
print(model.get_parameters())
model.load(load_path, env=env)
model = Model.load(load_path, env=env)
if n_eval_episodes:
mean_reward, std_reward = evaluate_policy(