diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py
index 90534cd..ae9adba 100644
--- a/metastable_baselines/distributions/distributions.py
+++ b/metastable_baselines/distributions/distributions.py
@@ -20,6 +20,7 @@ from stable_baselines3.common.distributions import (
 from stable_baselines3.common.distributions import DiagGaussianDistribution
 
 from ..misc.tensor_ops import fill_triangular
+from ..misc.tanhBijector import TanhBijector
 
 # TODO: Integrate and Test what I currently have before adding more complexity
 # TODO: Support Squashed Dists (tanh)
@@ -67,6 +68,9 @@ class ProbSquashingType(Enum):
     def apply(self, x):
         return [nn.Identity(), th.tanh][self.value](x)
 
+    def apply_inv(self, x):
+        return [nn.Identity(), TanhBijector.inverse][self.value](x)
+
 
 def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
     allowedEPTs = allowedEPTs or EnforcePositiveType
@@ -130,7 +134,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
     :param action_dim: Dimension of the action space.
     """
 
-    def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE):
+    def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-6):
         super(UniversalGaussianDistribution, self).__init__()
         self.action_dim = action_dim
         self.par_strength = neural_strength
@@ -139,7 +143,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
         self.enforce_positive_type = enforce_positive_type
         self.prob_squashing_type = prob_squashing_type
 
+        self.epsilon = epsilon
+
         self.distribution = None
+        self.gaussian_actions = None
 
         if self.prob_squashing_type != ProbSquashingType.NONE:
             raise Exception('ProbSquashing is not yet implemented!')
@@ -209,7 +216,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
             raise Exception('Unable to create torch distribution')
         return self
 
-    def log_prob(self, actions: th.Tensor) -> th.Tensor:
+    def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
         """
         Get the log probabilities of actions according to the distribution.
         Note that you must first call the ``proba_distribution()`` method.
@@ -217,18 +224,37 @@ class UniversalGaussianDistribution(SB3_Distribution):
         :param actions:
         :return:
         """
-        log_prob = self.distribution.log_prob(actions)
-        return log_prob
+        if self.prob_squashing_type == ProbSquashingType.NONE:
+            log_prob = self.distribution.log_prob(actions)
+            return log_prob
+
+        if gaussian_actions is None:
+            # Actions will be clipped to avoid NaN when inverting tanh
+            gaussian_actions = self.prob_squashing_type.apply_inv(actions)
+
+        log_prob = self.distribution.log_prob(gaussian_actions)
+
+        if self.prob_squashing_type == ProbSquashingType.TANH:
+            log_prob -= th.sum(th.log(1 - actions **
+                               2 + self.epsilon), dim=1)
+            return log_prob
+
+        raise Exception('Unsupported ProbSquashingType')
 
     def entropy(self) -> th.Tensor:
+        # TODO: This will return incorrect results when using prob-squashing
         return self.distribution.entropy()
 
     def sample(self) -> th.Tensor:
         # Reparametrization trick to pass gradients
-        return self.distribution.rsample()
+        sample = self.distribution.rsample()
+        self.gaussian_actions = sample
+        return self.prob_squashing_type.apply(sample)
 
     def mode(self) -> th.Tensor:
-        return self.distribution.mean
+        mode = self.distribution.mean
+        self.gaussian_actions = mode
+        return self.prob_squashing_type.apply(mode)
 
     def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_pi=None) -> th.Tensor:
         # Update the proba distribution
@@ -245,7 +271,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
         :return:
         """
         actions = self.actions_from_params(mean_actions, log_std)
-        log_prob = self.log_prob(actions)
+        log_prob = self.log_prob(actions, self.gaussian_actions)
         return actions, log_prob
 
diff --git a/metastable_baselines/misc/tanhBijector.py b/metastable_baselines/misc/tanhBijector.py
new file mode 100644
index 0000000..d8ac168
--- /dev/null
+++ b/metastable_baselines/misc/tanhBijector.py
@@ -0,0 +1,44 @@
+import torch as th
+
+
+class TanhBijector:
+    """
+    Stolen from SB3
+
+    Bijective transformation of a probability distribution
+    using a squashing function (tanh)
+    TODO: use Pyro instead (https://pyro.ai/)
+    :param epsilon: small value to avoid NaN due to numerical imprecision.
+    """
+
+    def __init__(self, epsilon: float = 1e-6):
+        super().__init__()
+        self.epsilon = epsilon
+
+    @staticmethod
+    def forward(x: th.Tensor) -> th.Tensor:
+        return th.tanh(x)
+
+    @staticmethod
+    def atanh(x: th.Tensor) -> th.Tensor:
+        """
+        Inverse of Tanh
+        Taken from Pyro: https://github.com/pyro-ppl/pyro
+        0.5 * torch.log((1 + x ) / (1 - x))
+        """
+        return 0.5 * (x.log1p() - (-x).log1p())
+
+    @staticmethod
+    def inverse(y: th.Tensor) -> th.Tensor:
+        """
+        Inverse tanh.
+        :param y:
+        :return:
+        """
+        eps = th.finfo(y.dtype).eps
+        # Clip the action to avoid NaN
+        return TanhBijector.atanh(y.clamp(min=-1.0 + eps, max=1.0 - eps))
+
+    def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
+        # Squash correction (from original SAC implementation)
+        return th.log(1.0 - th.tanh(x) ** 2 + self.epsilon)
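
Reviewer note (a sketch, not part of the patch): the TANH branch in `log_prob` applies the standard change-of-variables correction for `a = tanh(u)`, i.e. `log p(a) = log p(u) - sum_i log(1 - a_i^2 + epsilon)`. The snippet below cross-checks that correction and `TanhBijector.inverse` against PyTorch's exact `TanhTransform`; the batch/event shapes, the `Independent` wrapper, and the tolerances are illustrative assumptions, only `epsilon=1e-6` matches the patch default.

```python
import torch as th
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

from metastable_baselines.misc.tanhBijector import TanhBijector

# Hypothetical shapes: batch of 4 actions, action_dim = 3
base = Independent(Normal(th.zeros(4, 3), th.ones(4, 3)), 1)

gaussian_actions = base.sample()                  # u, unsquashed
actions = TanhBijector.forward(gaussian_actions)  # a = tanh(u), in (-1, 1)

# Correction as in the TANH branch of UniversalGaussianDistribution.log_prob:
# log p(a) = log p(u) - sum_i log(1 - a_i**2 + epsilon)
epsilon = 1e-6
log_prob = base.log_prob(gaussian_actions) - \
    th.sum(th.log(1 - actions ** 2 + epsilon), dim=1)

# Cross-check against torch's exact tanh change of variables;
# differences stem only from the epsilon term.
ref = TransformedDistribution(base, TanhTransform())
print(th.allclose(log_prob, ref.log_prob(actions), atol=1e-3))

# Round-trip: inverse recovers u up to clamping near the saturation points
print(th.allclose(TanhBijector.inverse(actions), gaussian_actions, atol=1e-4))
```

This also illustrates why `sample()`/`mode()` now cache `self.gaussian_actions`: passing the unsquashed `u` back into `log_prob` avoids the lossy `atanh` round-trip entirely.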