diff --git a/replay.py b/replay.py
index 448799c..1c1f484 100755
--- a/replay.py
+++ b/replay.py
@@ -15,35 +15,25 @@ import columbus
 
 
 def main(load_path, n_eval_episodes=0):
-    load_path = load_path.replace('.zip','')
-    load_path = load_path.replace("'",'')
-    load_path = load_path.replace(' ','')
+    load_path = load_path.replace('.zip', '')
+    load_path = load_path.replace("'", '')
+    load_path = load_path.replace(' ', '')
     file_name = load_path.split('/')[-1]
     # TODO: Ugly, Ugly, Ugly:
     env_name = file_name.split('_')[0]
     alg_name = file_name.split('_')[1]
     alg_deriv = file_name.split('_')[2]
-    use_sde = file_name.find('sde')!=-1
+    use_sde = file_name.find('sde') != -1
     print(env_name, alg_name, alg_deriv, use_sde)
     env = gym.make(env_name)
-    if alg_name=='ppo':
-        model = PPO(
-            "MlpPolicy",
-            env,
-            use_sde=use_sde,
-        )
-    elif alg_name=='trl' and alg_deriv=='pg':
-        model = TRL_PG(
-            "MlpPolicy",
-            env,
-            use_sde=use_sde,
-        )
+    if alg_name == 'ppo':
+        Model = PPO
+    elif alg_name == 'trl' and alg_deriv == 'pg':
+        Model = TRL_PG
     else:
         raise Exception('Algorithm not implemented for replay')
 
-    print(model.get_parameters())
-
-    model.load(load_path, env=env)
+    model = Model.load(load_path, env=env)
 
     if n_eval_episodes:
         mean_reward, std_reward = evaluate_policy(