diff --git a/replay.py b/replay.py
new file mode 100755
index 0000000..448799c
--- /dev/null
+++ b/replay.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python3
+import time
+
+import gym
+
+from stable_baselines3 import PPO
+from stable_baselines3.common.evaluation import evaluate_policy
+
+from sb3_trl.trl_pg import TRL_PG
+import columbus  # imported for its side effect of registering the Columbus envs
+
+
+def main(load_path, n_eval_episodes=0):
+    # Strip the extension and any quoting/whitespace from a pasted path.
+    load_path = load_path.replace('.zip', '')
+    load_path = load_path.replace("'", '')
+    load_path = load_path.replace(' ', '')
+    file_name = load_path.split('/')[-1]
+    # TODO: Ugly, Ugly, Ugly: recover the configuration from the file name,
+    # which is expected to look like '<env>_<alg>_<deriv>..._sde...'.
+    env_name = file_name.split('_')[0]
+    alg_name = file_name.split('_')[1]
+    alg_deriv = file_name.split('_')[2]
+    use_sde = 'sde' in file_name
+    print(env_name, alg_name, alg_deriv, use_sde)
+    env = gym.make(env_name)
+    if alg_name == 'ppo':
+        model = PPO(
+            "MlpPolicy",
+            env,
+            use_sde=use_sde,
+        )
+    elif alg_name == 'trl' and alg_deriv == 'pg':
+        model = TRL_PG(
+            "MlpPolicy",
+            env,
+            use_sde=use_sde,
+        )
+    else:
+        raise NotImplementedError('Algorithm not implemented for replay')
+
+    print(model.get_parameters())
+
+    # load() is a classmethod that returns a new model; the result has to be
+    # assigned, otherwise the freshly initialized weights would be replayed.
+    model = model.load(load_path, env=env)
+
+    if n_eval_episodes:
+        mean_reward, std_reward = evaluate_policy(
+            model, env, n_eval_episodes=n_eval_episodes, deterministic=False)
+        print('Reward: ' + str(round(mean_reward, 3))
+              + '±' + str(round(std_reward, 2)))
+
+    input('')  # wait for <enter> before starting the render loop
+    obs = env.reset()
+    episode_reward = 0
+    while True:
+        time.sleep(1 / 30)  # cap the replay at roughly 30 fps
+        action, _ = model.predict(obs, deterministic=False)
+        obs, reward, done, info = env.step(action)
+        env.render()
+        episode_reward += reward
+        if done:
+            episode_reward = 0.0
+            obs = env.reset()
+
+
+if __name__ == '__main__':
+    main(input('[path to model> '))
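
Note: the script infers everything from the checkpoint's file name, so saved models have to follow the '<env>_<alg>_<deriv>...' naming convention. A minimal sketch of that parsing contract, using a made-up checkpoint name (no real file is referenced):

    # Hypothetical checkpoint name, chosen only to illustrate the parsing in main():
    file_name = 'ColumbusEasy-v0_trl_pg_sde.zip'
    stem = file_name.replace('.zip', '')
    env_name, alg_name, alg_deriv = stem.split('_')[:3]
    use_sde = 'sde' in stem
    print(env_name, alg_name, alg_deriv, use_sde)
    # -> ColumbusEasy-v0 trl pg True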