#!/usr/bin/env python3
"""Replay a saved Stable-Baselines3 agent in its gym environment.

The environment id and algorithm are recovered from the saved model's file
name, which is expected to look like ``<env>_<alg>_<deriv>...``. The model is
then optionally evaluated and rendered in an endless loop (stop with Ctrl-C).
"""
import gym
from gym.envs.registration import register
import numpy as np
import os
import time
import datetime

from stable_baselines3 import SAC, PPO, A2C
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy

from sb3_trl.trl_pg import TRL_PG

import columbus  # not referenced directly; presumably imported to register its gym envs — TODO confirm


def _clean_path(raw_path):
    """Normalise a user-supplied model path.

    Strips surrounding whitespace and single/double quotes (as left behind by
    pasting or drag-and-drop into a terminal) and a trailing ``.zip``
    extension. Interior spaces are kept so paths containing spaces still work.
    """
    path = raw_path.strip().strip('\'"').strip()
    if path.endswith('.zip'):
        path = path[:-len('.zip')]
    return path


def _parse_model_name(file_name):
    """Recover ``(env_name, alg_name, alg_deriv, use_sde)`` from a model file name.

    :param file_name: Base name of the saved model, following the naming
        convention ``<env>_<alg>_<deriv>[_...]``.
    :return: The first three ``_``-separated fields, and whether ``'sde'``
        occurs anywhere in the name.
    :raises ValueError: If the name has fewer than three ``_``-separated fields.
    """
    fields = file_name.split('_')
    if len(fields) < 3:
        raise ValueError(
            f"Cannot infer env/algorithm from model name {file_name!r}; "
            "expected '<env>_<alg>_<deriv>...'")
    env_name, alg_name, alg_deriv = fields[:3]
    # NOTE: plain substring match, so 'sde' anywhere in the name (even inside
    # the env id) sets the flag.
    use_sde = 'sde' in file_name
    return env_name, alg_name, alg_deriv, use_sde


def main(load_path, n_eval_episodes=0):
    """Load a saved model, optionally evaluate it, then render it until Ctrl-C.

    :param load_path: Path to the saved model. A ``.zip`` suffix and
        surrounding quotes/whitespace are tolerated.
    :param n_eval_episodes: If non-zero, run ``evaluate_policy`` for this many
        (stochastic) episodes and print the mean ± std reward first.
    :raises ValueError: If the file name does not follow the
        ``<env>_<alg>_<deriv>`` convention.
    :raises NotImplementedError: If the algorithm encoded in the file name has
        no replay support.
    """
    load_path = _clean_path(load_path)
    env_name, alg_name, alg_deriv, use_sde = _parse_model_name(
        os.path.basename(load_path))
    print(env_name, alg_name, alg_deriv, use_sde)

    env = gym.make(env_name)

    if alg_name == 'ppo':
        algorithm = PPO
    elif alg_name == 'trl' and alg_deriv == 'pg':
        algorithm = TRL_PG
    else:
        raise NotImplementedError('Algorithm not implemented for replay')

    # ``load`` is a classmethod that builds a *new* model from the saved
    # hyperparameters (including use_sde) and weights; it does not update an
    # existing instance in place, so its return value must be kept.
    model = algorithm.load(load_path, env=env)
    print(model.get_parameters())

    if n_eval_episodes:
        mean_reward, std_reward = evaluate_policy(
            model, env, n_eval_episodes=n_eval_episodes, deterministic=False)
        print(f'Reward: {round(mean_reward, 3)}±{round(std_reward, 2)}')

    input('')  # wait for the user before starting the render loop

    obs = env.reset()
    episode_reward = 0.0
    try:
        while True:
            time.sleep(1 / 30)  # throttle to roughly 30 steps per second
            action, _ = model.predict(obs, deterministic=False)
            obs, reward, done, info = env.step(action)
            env.render()
            episode_reward += reward
            if done:
                print('Episode reward:', episode_reward)
                episode_reward = 0.0
                obs = env.reset()
    except KeyboardInterrupt:
        pass  # Ctrl-C is the intended way to leave the endless replay loop
    finally:
        env.close()


if __name__ == '__main__':
    main(input('[path to model> '))