ProbSquashing implemented (tanh)

This commit is contained in:
Dominik Moritz Roth 2022-07-20 10:32:19 +02:00
parent 05dad44b6e
commit 199ce0c8cb
2 changed files with 77 additions and 7 deletions

View File

@ -20,6 +20,7 @@ from stable_baselines3.common.distributions import (
from stable_baselines3.common.distributions import DiagGaussianDistribution from stable_baselines3.common.distributions import DiagGaussianDistribution
from ..misc.tensor_ops import fill_triangular from ..misc.tensor_ops import fill_triangular
from ..misc.tanhBijector import TanhBijector
# TODO: Integrate and Test what I currently have before adding more complexity # TODO: Integrate and Test what I currently have before adding more complexity
# TODO: Support Squashed Dists (tanh) # TODO: Support Squashed Dists (tanh)
@ -67,6 +68,9 @@ class ProbSquashingType(Enum):
def apply(self, x): def apply(self, x):
return [nn.Identity(), th.tanh][self.value](x) return [nn.Identity(), th.tanh][self.value](x)
def apply_inv(self, x):
return [nn.Identity(), TanhBijector.inverse][self.value](x)
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None): def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
allowedEPTs = allowedEPTs or EnforcePositiveType allowedEPTs = allowedEPTs or EnforcePositiveType
@ -130,7 +134,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
:param action_dim: Dimension of the action space. :param action_dim: Dimension of the action space.
""" """
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE): def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-6):
super(UniversalGaussianDistribution, self).__init__() super(UniversalGaussianDistribution, self).__init__()
self.action_dim = action_dim self.action_dim = action_dim
self.par_strength = neural_strength self.par_strength = neural_strength
@ -139,7 +143,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
self.enforce_positive_type = enforce_positive_type self.enforce_positive_type = enforce_positive_type
self.prob_squashing_type = prob_squashing_type self.prob_squashing_type = prob_squashing_type
self.epsilon = epsilon
self.distribution = None self.distribution = None
self.gaussian_actions = None
if self.prob_squashing_type != ProbSquashingType.NONE: if self.prob_squashing_type != ProbSquashingType.NONE:
raise Exception('ProbSquasing is not yet implmenented!') raise Exception('ProbSquasing is not yet implmenented!')
@ -209,7 +216,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
raise Exception('Unable to create torch distribution') raise Exception('Unable to create torch distribution')
return self return self
def log_prob(self, actions: th.Tensor) -> th.Tensor: def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
""" """
Get the log probabilities of actions according to the distribution. Get the log probabilities of actions according to the distribution.
Note that you must first call the ``proba_distribution()`` method. Note that you must first call the ``proba_distribution()`` method.
@ -217,18 +224,37 @@ class UniversalGaussianDistribution(SB3_Distribution):
:param actions: :param actions:
:return: :return:
""" """
log_prob = self.distribution.log_prob(actions) if self.prob_squashing_type == ProbSquashingType.NONE:
return log_prob log_prob = self.distribution.log_prob(actions)
return log_prob
if gaussian_actions is None:
# It will be clipped to avoid NaN when inversing tanh
gaussian_actions = self.prob_squashing_type.apply_inv(actions)
log_prob = self.distribution.log_prob(gaussian_actions)
if self.prob_squashing_type == ProbSquashingType.TANH:
log_prob -= th.sum(th.log(1 - actions **
2 + self.epsilon), dim=1)
return log_prob
raise Exception()
def entropy(self) -> th.Tensor: def entropy(self) -> th.Tensor:
# TODO: This will return incorrect results when using prob-squashing
return self.distribution.entropy() return self.distribution.entropy()
def sample(self) -> th.Tensor: def sample(self) -> th.Tensor:
# Reparametrization trick to pass gradients # Reparametrization trick to pass gradients
return self.distribution.rsample() sample = self.distribution.rsample()
self.gaussian_actions = sample
return self.prob_squashing_type.apply(sample)
def mode(self) -> th.Tensor: def mode(self) -> th.Tensor:
return self.distribution.mean mode = self.distribution.mean
self.gaussian_actions = mode
return self.prob_squashing_type.apply(mode)
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_pi=None) -> th.Tensor: def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_pi=None) -> th.Tensor:
# Update the proba distribution # Update the proba distribution
@ -245,7 +271,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
:return: :return:
""" """
actions = self.actions_from_params(mean_actions, log_std) actions = self.actions_from_params(mean_actions, log_std)
log_prob = self.log_prob(actions) log_prob = self.log_prob(actions, self.gaussian_actions)
return actions, log_prob return actions, log_prob

View File

@ -0,0 +1,44 @@
import torch as th
class TanhBijector:
    """
    Tanh squashing bijection for probability distributions.

    Adapted from Stable-Baselines3.
    TODO: use Pyro instead (https://pyro.ai/)

    :param epsilon: small constant added inside the logarithm of
        ``log_prob_correction`` to avoid NaN due to numerical imprecision.
    """

    def __init__(self, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon

    @staticmethod
    def forward(x: th.Tensor) -> th.Tensor:
        """
        Squash ``x`` element-wise with tanh.

        :param x: unsquashed tensor.
        :return: ``tanh(x)``.
        """
        squashed = th.tanh(x)
        return squashed

    @staticmethod
    def atanh(x: th.Tensor) -> th.Tensor:
        """
        Inverse hyperbolic tangent, without any clipping of the input.

        Formulation taken from Pyro (https://github.com/pyro-ppl/pyro):
        ``0.5 * log((1 + x) / (1 - x))``, evaluated through ``log1p``.

        :param x: tensor of squashed values.
        :return: ``atanh(x)``.
        """
        log_one_plus_x = x.log1p()
        log_one_minus_x = (-x).log1p()
        return 0.5 * (log_one_plus_x - log_one_minus_x)

    @staticmethod
    def inverse(y: th.Tensor) -> th.Tensor:
        """
        Inverse tanh with the input clipped just inside (-1, 1).

        :param y: tensor of squashed values.
        :return: ``atanh`` of the clipped ``y``.
        """
        # Machine epsilon of the tensor's dtype sets the clipping margin;
        # atanh diverges at exactly +/-1, so clip to avoid NaN/inf.
        margin = th.finfo(y.dtype).eps
        lower_bound = -1.0 + margin
        upper_bound = 1.0 - margin
        clipped = y.clamp(min=lower_bound, max=upper_bound)
        return TanhBijector.atanh(clipped)

    def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
        """
        Element-wise squash correction term (from the original SAC
        implementation): ``log(1 - tanh(x)^2 + epsilon)``.

        :param x: unsquashed tensor.
        :return: tensor with the same shape as ``x``.
        """
        squashed = th.tanh(x)
        return th.log(1.0 - squashed ** 2 + self.epsilon)