From 30c9e9396754b2fbe05afa2c64df5788ab114b49 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 29 Jun 2022 17:02:40 +0200 Subject: [PATCH] Fixed replay --- replay.py | 28 +++++++++------------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/replay.py b/replay.py index 448799c..1c1f484 100755 --- a/replay.py +++ b/replay.py @@ -15,35 +15,25 @@ import columbus def main(load_path, n_eval_episodes=0): - load_path = load_path.replace('.zip','') - load_path = load_path.replace("'",'') - load_path = load_path.replace(' ','') + load_path = load_path.replace('.zip', '') + load_path = load_path.replace("'", '') + load_path = load_path.replace(' ', '') file_name = load_path.split('/')[-1] # TODO: Ugly, Ugly, Ugly: env_name = file_name.split('_')[0] alg_name = file_name.split('_')[1] alg_deriv = file_name.split('_')[2] - use_sde = file_name.find('sde')!=-1 + use_sde = file_name.find('sde') != -1 print(env_name, alg_name, alg_deriv, use_sde) env = gym.make(env_name) - if alg_name=='ppo': - model = PPO( - "MlpPolicy", - env, - use_sde=use_sde, - ) - elif alg_name=='trl' and alg_deriv=='pg': - model = TRL_PG( - "MlpPolicy", - env, - use_sde=use_sde, - ) + if alg_name == 'ppo': + Model = PPO + elif alg_name == 'trl' and alg_deriv == 'pg': + Model = TRL_PG else: raise Exception('Algorithm not implemented for replay') - print(model.get_parameters()) - - model.load(load_path, env=env) + model = Model.load(load_path, env=env) if n_eval_episodes: mean_reward, std_reward = evaluate_policy(