Fixed SAC+SDE+SDC bugs

This commit is contained in:
Dominik Moritz Roth 2022-09-03 13:08:31 +02:00
parent 4532135812
commit ee4a0eed56
4 changed files with 25 additions and 24 deletions

View File

@ -320,12 +320,12 @@ class UniversalGaussianDistribution(SB3_Distribution):
self.gaussian_actions = mode
return self.prob_squashing_type.apply(mode)
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_pi=None) -> th.Tensor:
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_sde=None) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(mean_actions, log_std, latent_pi=latent_pi)
self.proba_distribution(mean_actions, log_std, latent_sde=latent_sde)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde=None) -> Tuple[th.Tensor, th.Tensor]:
"""
Compute the log probability of taking an action
given the distribution parameters.
@ -334,7 +334,8 @@ class UniversalGaussianDistribution(SB3_Distribution):
:param log_std:
:return:
"""
actions = self.actions_from_params(mean_actions, log_std)
actions = self.actions_from_params(
mean_actions, log_std, latent_sde=latent_sde)
log_prob = self.log_prob(actions, self.gaussian_actions)
return actions, log_prob

View File

@ -167,7 +167,7 @@ class Actor(BasePolicy):
"""
msg = "get_std() is only available when using gSDE"
assert isinstance(self.action_dist,
StateDependentNoiseDistribution), msg
StateDependentNoiseDistribution) or (isinstance(self.action_dist, UniversalGaussianDistribution) and self.action_dist.use_sde), msg
return self.chol
def reset_noise(self, n_envs: int = 1) -> None:
@ -199,12 +199,13 @@ class Actor(BasePolicy):
latent_pi = self.latent_pi(features)
mean_actions = self.mu_net(latent_pi)
if self.use_sde:
return mean_actions, self.chol, dict(latent_sde=latent_pi)
# Unstructured exploration (Original implementation)
chol = self.chol_net(latent_pi)
self.chol = chol
# Original Implementation to cap the standard deviation
self.chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
# self.chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
if self.use_sde:
return mean_actions, self.chol, dict(latent_sde=latent_pi)
return mean_actions, self.chol, {}
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:

View File

@ -262,9 +262,8 @@ class SAC(OffPolicyAlgorithm):
latent_pi = act.latent_pi(features)
mean_actions = act.mu_net(latent_pi)
# TODO: Allow contextual covariance with sde
if self.use_sde:
chol = act.chol
chol = act.chol_net(latent_pi)
else:
# Unstructured exploration (Original implementation)
chol = act.chol_net(latent_pi)
@ -275,8 +274,8 @@ class SAC(OffPolicyAlgorithm):
act_dist = self.actor.action_dist
# internal A
if self.use_sde:
actions_pi = self.actions_from_params(
mean_actions, chol, latent_pi) # latent_pi = latent_sde
actions_pi = act_dist.actions_from_params(
mean_actions, chol, latent_sde=latent_pi) # latent_pi = latent_sde
else:
actions_pi = act_dist.actions_from_params(
mean_actions, chol)

24
test.py
View File

@ -22,26 +22,26 @@ root_path = '.'
def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
env = gym.make(env_name)
use_sde = False
use_sde = True
# th.autograd.set_detect_anomaly(True)
ppo = PPO(
MlpPolicyPPO,
sac = SAC(
MlpPolicySAC,
env,
# KLProjectionLayer(trust_region_coeff=0.01),
projection=WassersteinProjectionLayer(trust_region_coeff=0.01),
#projection=WassersteinProjectionLayer(trust_region_coeff=0.01),
policy_kwargs={'dist_kwargs': {'neural_strength': Strength.NONE, 'cov_strength': Strength.DIAG, 'parameterization_type':
ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
verbose=0,
tensorboard_log=root_path+"/logs_tb/" +
env_name+"/ppo"+(['', '_sde'][use_sde])+"/",
env_name+"/sac"+(['', '_sde'][use_sde])+"/",
learning_rate=3e-4, # 3e-4,
gamma=0.99,
gae_lambda=0.95,
normalize_advantage=True,
ent_coef=0.1, # 0.1
vf_coef=0.5,
#gae_lambda=0.95,
#normalize_advantage=True,
#ent_coef=0.1, # 0.1
#vf_coef=0.5,
use_sde=use_sde, # False
clip_range=None # 1 # 0.2,
#clip_range=None # 1 # 0.2,
)
# trl_frob = PPO(
# MlpPolicy,
@ -60,8 +60,8 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=Tru
# clip_range=2, # 0.2
# )
print('PPO:')
testModel(ppo, timesteps, showRes,
print('SAC:')
testModel(sac, timesteps, showRes,
saveModel, n_eval_episodes)
# print('TRL_frob:')
# testModel(trl_frob, timesteps, showRes,