Tiny fix for other envs
This commit is contained in:
parent
302dbf6dde
commit
4ec5c65cf2
@ -136,6 +136,8 @@ class ActorCriticPolicy(BasePolicy):
|
|||||||
self.normalize_images = normalize_images
|
self.normalize_images = normalize_images
|
||||||
self.log_std_init = log_std_init
|
self.log_std_init = log_std_init
|
||||||
# Keyword arguments for gSDE distribution
|
# Keyword arguments for gSDE distribution
|
||||||
|
if dist_kwargs == None:
|
||||||
|
dist_kwargs = {}
|
||||||
if use_sde:
|
if use_sde:
|
||||||
add_dist_kwargs = {
|
add_dist_kwargs = {
|
||||||
'use_sde': True,
|
'use_sde': True,
|
||||||
|
66
test.py
66
test.py
@ -24,49 +24,49 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=Tru
|
|||||||
env = gym.make(env_name)
|
env = gym.make(env_name)
|
||||||
use_sde = True
|
use_sde = True
|
||||||
# th.autograd.set_detect_anomaly(True)
|
# th.autograd.set_detect_anomaly(True)
|
||||||
sac = SAC(
|
#sac = SAC(
|
||||||
MlpPolicySAC,
|
# MlpPolicySAC,
|
||||||
env,
|
# env,
|
||||||
# KLProjectionLayer(trust_region_coeff=0.01),
|
# KLProjectionLayer(trust_region_coeff=0.01),
|
||||||
#projection=WassersteinProjectionLayer(trust_region_coeff=0.01),
|
#projection=WassersteinProjectionLayer(trust_region_coeff=0.01),
|
||||||
policy_kwargs={'dist_kwargs': {'neural_strength': Strength.NONE, 'cov_strength': Strength.DIAG, 'parameterization_type':
|
# policy_kwargs={'dist_kwargs': {'neural_strength': Strength.NONE, 'cov_strength': Strength.DIAG, 'parameterization_type':
|
||||||
ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
|
# ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
|
||||||
verbose=0,
|
# verbose=0,
|
||||||
tensorboard_log=root_path+"/logs_tb/" +
|
# tensorboard_log=root_path+"/logs_tb/" +
|
||||||
env_name+"/sac"+(['', '_sde'][use_sde])+"/",
|
# env_name+"/sac"+(['', '_sde'][use_sde])+"/",
|
||||||
learning_rate=3e-4, # 3e-4,
|
# learning_rate=3e-4, # 3e-4,
|
||||||
gamma=0.99,
|
# gamma=0.99,
|
||||||
#gae_lambda=0.95,
|
#gae_lambda=0.95,
|
||||||
#normalize_advantage=True,
|
#normalize_advantage=True,
|
||||||
#ent_coef=0.1, # 0.1
|
#ent_coef=0.1, # 0.1
|
||||||
#vf_coef=0.5,
|
#vf_coef=0.5,
|
||||||
use_sde=use_sde, # False
|
# use_sde=use_sde, # False
|
||||||
sde_sample_freq=8,
|
# sde_sample_freq=8,
|
||||||
#clip_range=None # 1 # 0.2,
|
#clip_range=None # 1 # 0.2,
|
||||||
)
|
|
||||||
# trl_frob = PPO(
|
|
||||||
# MlpPolicy,
|
|
||||||
# env,
|
|
||||||
# projection=FrobeniusProjectionLayer(),
|
|
||||||
# verbose=0,
|
|
||||||
# tensorboard_log=root_path+"/logs_tb/"+env_name +
|
|
||||||
# "/trl_frob"+(['', '_sde'][use_sde])+"/",
|
|
||||||
# learning_rate=3e-4,
|
|
||||||
# gamma=0.99,
|
|
||||||
# gae_lambda=0.95,
|
|
||||||
# normalize_advantage=True,
|
|
||||||
# ent_coef=0.03, # 0.1
|
|
||||||
# vf_coef=0.5,
|
|
||||||
# use_sde=use_sde,
|
|
||||||
# clip_range=2, # 0.2
|
|
||||||
#)
|
#)
|
||||||
|
trl_frob = PPO(
|
||||||
|
MlpPolicyPPO,
|
||||||
|
env,
|
||||||
|
projection=FrobeniusProjectionLayer(),
|
||||||
|
verbose=0,
|
||||||
|
tensorboard_log=root_path+"/logs_tb/"+env_name +
|
||||||
|
"/trl_frob"+(['', '_sde'][use_sde])+"/",
|
||||||
|
learning_rate=3e-4,
|
||||||
|
gamma=0.99,
|
||||||
|
gae_lambda=0.95,
|
||||||
|
normalize_advantage=True,
|
||||||
|
ent_coef=0.03, # 0.1
|
||||||
|
vf_coef=0.5,
|
||||||
|
use_sde=use_sde,
|
||||||
|
clip_range=2, # 0.2
|
||||||
|
)
|
||||||
|
|
||||||
print('SAC:')
|
#print('SAC:')
|
||||||
testModel(sac, timesteps, showRes,
|
#testModel(sac, timesteps, showRes,
|
||||||
saveModel, n_eval_episodes)
|
|
||||||
# print('TRL_frob:')
|
|
||||||
# testModel(trl_frob, timesteps, showRes,
|
|
||||||
# saveModel, n_eval_episodes)
|
# saveModel, n_eval_episodes)
|
||||||
|
print('TRL_frob:')
|
||||||
|
testModel(trl_frob, timesteps, showRes,
|
||||||
|
saveModel, n_eval_episodes)
|
||||||
|
|
||||||
|
|
||||||
def full(env_name='ColumbusCandyland_Aux10-v0', timesteps=200_000, saveModel=True, n_eval_episodes=4):
|
def full(env_name='ColumbusCandyland_Aux10-v0', timesteps=200_000, saveModel=True, n_eval_episodes=4):
|
||||||
|
Loading…
Reference in New Issue
Block a user