Fixed replay

This commit is contained in:
Dominik Moritz Roth 2022-06-29 17:02:40 +02:00
parent 28561b9bb2
commit 30c9e93967

View File

@ -15,35 +15,25 @@ import columbus
def main(load_path, n_eval_episodes=0): def main(load_path, n_eval_episodes=0):
load_path = load_path.replace('.zip','') load_path = load_path.replace('.zip', '')
load_path = load_path.replace("'",'') load_path = load_path.replace("'", '')
load_path = load_path.replace(' ','') load_path = load_path.replace(' ', '')
file_name = load_path.split('/')[-1] file_name = load_path.split('/')[-1]
# TODO: Ugly, Ugly, Ugly: # TODO: Ugly, Ugly, Ugly:
env_name = file_name.split('_')[0] env_name = file_name.split('_')[0]
alg_name = file_name.split('_')[1] alg_name = file_name.split('_')[1]
alg_deriv = file_name.split('_')[2] alg_deriv = file_name.split('_')[2]
use_sde = file_name.find('sde')!=-1 use_sde = file_name.find('sde') != -1
print(env_name, alg_name, alg_deriv, use_sde) print(env_name, alg_name, alg_deriv, use_sde)
env = gym.make(env_name) env = gym.make(env_name)
if alg_name=='ppo': if alg_name == 'ppo':
model = PPO( Model = PPO
"MlpPolicy", elif alg_name == 'trl' and alg_deriv == 'pg':
env, Model = TRL_PG
use_sde=use_sde,
)
elif alg_name=='trl' and alg_deriv=='pg':
model = TRL_PG(
"MlpPolicy",
env,
use_sde=use_sde,
)
else: else:
raise Exception('Algorithm not implemented for replay') raise Exception('Algorithm not implemented for replay')
print(model.get_parameters()) model = Model.load(load_path, env=env)
model.load(load_path, env=env)
if n_eval_episodes: if n_eval_episodes:
mean_reward, std_reward = evaluate_policy( mean_reward, std_reward = evaluate_policy(