Fixed replay
This commit is contained in:
parent
28561b9bb2
commit
30c9e93967
16
replay.py
16
replay.py
@ -27,23 +27,13 @@ def main(load_path, n_eval_episodes=0):
|
||||
print(env_name, alg_name, alg_deriv, use_sde)
|
||||
env = gym.make(env_name)
|
||||
if alg_name == 'ppo':
|
||||
model = PPO(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
use_sde=use_sde,
|
||||
)
|
||||
Model = PPO
|
||||
elif alg_name == 'trl' and alg_deriv == 'pg':
|
||||
model = TRL_PG(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
use_sde=use_sde,
|
||||
)
|
||||
Model = TRL_PG
|
||||
else:
|
||||
raise Exception('Algorithm not implemented for replay')
|
||||
|
||||
print(model.get_parameters())
|
||||
|
||||
model.load(load_path, env=env)
|
||||
model = Model.load(load_path, env=env)
|
||||
|
||||
if n_eval_episodes:
|
||||
mean_reward, std_reward = evaluate_policy(
|
||||
|
Loading…
Reference in New Issue
Block a user