Fixed replay
This commit is contained in:
parent
28561b9bb2
commit
30c9e93967
28
replay.py
28
replay.py
@ -15,35 +15,25 @@ import columbus
|
|||||||
|
|
||||||
|
|
||||||
def main(load_path, n_eval_episodes=0):
|
def main(load_path, n_eval_episodes=0):
|
||||||
load_path = load_path.replace('.zip','')
|
load_path = load_path.replace('.zip', '')
|
||||||
load_path = load_path.replace("'",'')
|
load_path = load_path.replace("'", '')
|
||||||
load_path = load_path.replace(' ','')
|
load_path = load_path.replace(' ', '')
|
||||||
file_name = load_path.split('/')[-1]
|
file_name = load_path.split('/')[-1]
|
||||||
# TODO: Ugly, Ugly, Ugly:
|
# TODO: Ugly, Ugly, Ugly:
|
||||||
env_name = file_name.split('_')[0]
|
env_name = file_name.split('_')[0]
|
||||||
alg_name = file_name.split('_')[1]
|
alg_name = file_name.split('_')[1]
|
||||||
alg_deriv = file_name.split('_')[2]
|
alg_deriv = file_name.split('_')[2]
|
||||||
use_sde = file_name.find('sde')!=-1
|
use_sde = file_name.find('sde') != -1
|
||||||
print(env_name, alg_name, alg_deriv, use_sde)
|
print(env_name, alg_name, alg_deriv, use_sde)
|
||||||
env = gym.make(env_name)
|
env = gym.make(env_name)
|
||||||
if alg_name=='ppo':
|
if alg_name == 'ppo':
|
||||||
model = PPO(
|
Model = PPO
|
||||||
"MlpPolicy",
|
elif alg_name == 'trl' and alg_deriv == 'pg':
|
||||||
env,
|
Model = TRL_PG
|
||||||
use_sde=use_sde,
|
|
||||||
)
|
|
||||||
elif alg_name=='trl' and alg_deriv=='pg':
|
|
||||||
model = TRL_PG(
|
|
||||||
"MlpPolicy",
|
|
||||||
env,
|
|
||||||
use_sde=use_sde,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise Exception('Algorithm not implemented for replay')
|
raise Exception('Algorithm not implemented for replay')
|
||||||
|
|
||||||
print(model.get_parameters())
|
model = Model.load(load_path, env=env)
|
||||||
|
|
||||||
model.load(load_path, env=env)
|
|
||||||
|
|
||||||
if n_eval_episodes:
|
if n_eval_episodes:
|
||||||
mean_reward, std_reward = evaluate_policy(
|
mean_reward, std_reward = evaluate_policy(
|
||||||
|
Loading…
Reference in New Issue
Block a user