Fixed SAC+SDE+SDC bugs

parent 4532135812
commit ee4a0eed56
@@ -320,12 +320,12 @@ class UniversalGaussianDistribution(SB3_Distribution):
         self.gaussian_actions = mode
         return self.prob_squashing_type.apply(mode)
 
-    def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_pi=None) -> th.Tensor:
+    def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_sde=None) -> th.Tensor:
         # Update the proba distribution
-        self.proba_distribution(mean_actions, log_std, latent_pi=latent_pi)
+        self.proba_distribution(mean_actions, log_std, latent_sde=latent_sde)
         return self.get_actions(deterministic=deterministic)
 
-    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde=None) -> Tuple[th.Tensor, th.Tensor]:
         """
         Compute the log probability of taking an action
         given the distribution parameters.
@@ -334,7 +334,8 @@ class UniversalGaussianDistribution(SB3_Distribution):
         :param log_std:
         :return:
         """
-        actions = self.actions_from_params(mean_actions, log_std)
+        actions = self.actions_from_params(
+            mean_actions, log_std, latent_sde=latent_sde)
         log_prob = self.log_prob(actions, self.gaussian_actions)
         return actions, log_prob
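The hunks above rename the distribution's `latent_pi` keyword to `latent_sde` and thread it through `log_prob_from_params()`, so the actor's latent features actually reach the state-dependent-exploration path instead of being dropped. A minimal, hypothetical sketch of that plumbing (the class and attribute names below are illustrative, not the repository's actual `UniversalGaussianDistribution` API):

import torch as th

class ToyStateDependentGaussian:
    """Illustrative stand-in showing how latent_sde flows through the calls."""

    def __init__(self, latent_dim: int, action_dim: int):
        # Fixed exploration matrix, as in gSDE; normally resampled via reset_noise().
        self.exploration_mat = th.randn(latent_dim, action_dim)

    def proba_distribution(self, mean_actions, log_std, latent_sde=None):
        self.mean_actions = mean_actions
        self.std = th.exp(log_std)
        self.latent_sde = latent_sde
        return self

    def actions_from_params(self, mean_actions, log_std, deterministic=False, latent_sde=None):
        # Same pattern as the fixed method: update the distribution, then sample.
        self.proba_distribution(mean_actions, log_std, latent_sde=latent_sde)
        if deterministic or latent_sde is None:
            return mean_actions
        # State-dependent noise: the noise depends on the policy's latent features.
        return mean_actions + (latent_sde @ self.exploration_mat) * self.std

latent = th.randn(4, 8)                       # latent features from the policy network
mean, log_std = th.zeros(4, 2), th.zeros(2)
dist = ToyStateDependentGaussian(latent_dim=8, action_dim=2)
actions = dist.actions_from_params(mean, log_std, latent_sde=latent)
print(actions.shape)  # torch.Size([4, 2])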
@@ -167,7 +167,7 @@ class Actor(BasePolicy):
         """
         msg = "get_std() is only available when using gSDE"
         assert isinstance(self.action_dist,
-                          StateDependentNoiseDistribution), msg
+                          StateDependentNoiseDistribution) or (isinstance(self.action_dist, UniversalGaussianDistribution) and self.action_dist.use_sde), msg
         return self.chol
 
     def reset_noise(self, n_envs: int = 1) -> None:
@@ -199,12 +199,13 @@ class Actor(BasePolicy):
         latent_pi = self.latent_pi(features)
         mean_actions = self.mu_net(latent_pi)
 
-        if self.use_sde:
-            return mean_actions, self.chol, dict(latent_sde=latent_pi)
         # Unstructured exploration (Original implementation)
         chol = self.chol_net(latent_pi)
+        self.chol = chol
         # Original Implementation to cap the standard deviation
-        self.chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
+        # self.chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
+        if self.use_sde:
+            return mean_actions, self.chol, dict(latent_sde=latent_pi)
         return mean_actions, self.chol, {}
 
     def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
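In the Actor, `get_std()` now also accepts a `UniversalGaussianDistribution` with SDE enabled, and the distribution-parameter computation evaluates the covariance head before the gSDE branch with the hard `LOG_STD` clamp commented out, so `self.chol` is a fresh, state-dependent output in both paths. A small hedged sketch of the resulting control flow (module names and sizes below are made up for illustration):

import torch as th
import torch.nn as nn

class ToyActor(nn.Module):
    def __init__(self, feat_dim=8, act_dim=2, use_sde=True):
        super().__init__()
        self.use_sde = use_sde
        self.latent_pi = nn.Sequential(nn.Linear(feat_dim, 16), nn.ReLU())
        self.mu_net = nn.Linear(16, act_dim)
        self.chol_net = nn.Linear(16, act_dim)   # diagonal covariance head

    def get_dist_params(self, features):
        latent_pi = self.latent_pi(features)
        mean_actions = self.mu_net(latent_pi)
        # Covariance is computed before branching, so it is never stale,
        # and no hard clamp is applied here.
        self.chol = self.chol_net(latent_pi)
        if self.use_sde:
            return mean_actions, self.chol, dict(latent_sde=latent_pi)
        return mean_actions, self.chol, {}

actor = ToyActor()
mean, chol, extra = actor.get_dist_params(th.randn(4, 8))
print(mean.shape, chol.shape, list(extra))  # torch.Size([4, 2]) torch.Size([4, 2]) ['latent_sde']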
@@ -262,9 +262,8 @@ class SAC(OffPolicyAlgorithm):
             latent_pi = act.latent_pi(features)
             mean_actions = act.mu_net(latent_pi)
 
-            # TODO: Allow contextual covariance with sde
             if self.use_sde:
-                chol = act.chol
+                chol = act.chol_net(latent_pi)
             else:
                 # Unstructured exploration (Original implementation)
                 chol = act.chol_net(latent_pi)
@@ -275,8 +274,8 @@ class SAC(OffPolicyAlgorithm):
             act_dist = self.actor.action_dist
             # internal A
             if self.use_sde:
-                actions_pi = self.actions_from_params(
-                    mean_actions, chol, latent_pi)  # latent_pi = latent_sde
+                actions_pi = act_dist.actions_from_params(
+                    mean_actions, chol, latent_sde=latent_pi)  # latent_pi = latent_sde
             else:
                 actions_pi = act_dist.actions_from_params(
                     mean_actions, chol)
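Two fixes in SAC's training step: the SDE branch now recomputes the Cholesky factor from `act.chol_net(latent_pi)` instead of reusing the cached `act.chol`, and `actions_from_params()` is called on the actor's `action_dist` with `latent_sde` passed as a keyword. With the old positional call, the latent features would have been bound to the `deterministic` parameter. A tiny stand-in with the same signature makes the difference visible (purely illustrative):

def actions_from_params(mean_actions, log_std, deterministic=False, latent_sde=None):
    # Stand-in with the same signature as the fixed method.
    return {"deterministic": bool(deterministic), "got_latent_sde": latent_sde is not None}

latent = [0.1, 0.2]
print(actions_from_params(0.0, 0.0, latent))             # {'deterministic': True, 'got_latent_sde': False}
print(actions_from_params(0.0, 0.0, latent_sde=latent))  # {'deterministic': False, 'got_latent_sde': True}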
test.py (24 changed lines)
@@ -22,26 +22,26 @@ root_path = '.'
 
 def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
     env = gym.make(env_name)
-    use_sde = False
+    use_sde = True
     # th.autograd.set_detect_anomaly(True)
-    ppo = PPO(
-        MlpPolicyPPO,
+    sac = SAC(
+        MlpPolicySAC,
         env,
         # KLProjectionLayer(trust_region_coeff=0.01),
-        projection=WassersteinProjectionLayer(trust_region_coeff=0.01),
+        #projection=WassersteinProjectionLayer(trust_region_coeff=0.01),
         policy_kwargs={'dist_kwargs': {'neural_strength': Strength.NONE, 'cov_strength': Strength.DIAG, 'parameterization_type':
                                        ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
         verbose=0,
         tensorboard_log=root_path+"/logs_tb/" +
-        env_name+"/ppo"+(['', '_sde'][use_sde])+"/",
+        env_name+"/sac"+(['', '_sde'][use_sde])+"/",
         learning_rate=3e-4,  # 3e-4,
         gamma=0.99,
-        gae_lambda=0.95,
-        normalize_advantage=True,
-        ent_coef=0.1,  # 0.1
-        vf_coef=0.5,
+        #gae_lambda=0.95,
+        #normalize_advantage=True,
+        #ent_coef=0.1,  # 0.1
+        #vf_coef=0.5,
         use_sde=use_sde,  # False
-        clip_range=None  # 1 # 0.2,
+        #clip_range=None  # 1 # 0.2,
     )
     # trl_frob = PPO(
     #     MlpPolicy,
@@ -60,8 +60,8 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=Tru
     #     clip_range=2,  # 0.2
     # )
 
-    print('PPO:')
-    testModel(ppo, timesteps, showRes,
+    print('SAC:')
+    testModel(sac, timesteps, showRes,
              saveModel, n_eval_episodes)
     # print('TRL_frob:')
     # testModel(trl_frob, timesteps, showRes,
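test.py now drives SAC instead of PPO; the projection layer and the on-policy-only arguments (gae_lambda, normalize_advantage, vf_coef, clip_range) are commented out since SAC's constructor typically does not accept them. For comparison, a plain stable-baselines3 SAC setup with gSDE looks roughly like this (environment and hyperparameters are illustrative, and the import is `gym` or `gymnasium` depending on the installed SB3 version):

import gymnasium as gym
from stable_baselines3 import SAC

env = gym.make("Pendulum-v1")
model = SAC(
    "MlpPolicy",
    env,
    learning_rate=3e-4,
    gamma=0.99,
    use_sde=True,          # state-dependent exploration
    verbose=0,
    tensorboard_log="./logs_tb/Pendulum-v1/sac_sde/",
)
model.learn(total_timesteps=10_000)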