diff --git a/replay.py b/replay.py index f4554fc..c135b9f 100755 --- a/replay.py +++ b/replay.py @@ -28,6 +28,8 @@ def main(load_path, n_eval_episodes=0): model = PPO.load(load_path, env=env) + show_chol = env_name.startswith('Columbus') + if n_eval_episodes: mean_reward, std_reward = evaluate_policy( model, env, n_eval_episodes=n_eval_episodes, deterministic=False) @@ -41,7 +43,10 @@ def main(load_path, n_eval_episodes=0): time.sleep(1/30) action, _ = model.predict(obs, deterministic=False) obs, reward, done, info = env.step(action) - env.render() + if show_chol: + env.render(chol=model.policy.chol) + else: + env.render() episode_reward += reward if done: episode_reward = 0.0