import torch as th


class TanhBijector:
    """
    Tanh squashing transform for probability distributions (adapted from
    Stable-Baselines3).

    Provides the forward map ``tanh``, its inverse, and the log-term used to
    correct log-probabilities after squashing.

    TODO: use Pyro instead (https://pyro.ai/)

    :param epsilon: small constant added inside the logarithm of
        ``log_prob_correction`` to avoid NaN due to numerical imprecision.
    """

    def __init__(self, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon

    @staticmethod
    def forward(x: th.Tensor) -> th.Tensor:
        """
        Squash the input element-wise with tanh.

        :param x: tensor to squash
        :return: ``tanh(x)``
        """
        squashed = th.tanh(x)
        return squashed

    @staticmethod
    def atanh(x: th.Tensor) -> th.Tensor:
        """
        Element-wise inverse hyperbolic tangent (formulation taken from
        Pyro: https://github.com/pyro-ppl/pyro).

        Equivalent to ``0.5 * log((1 + x) / (1 - x))``, written with
        ``log1p`` as ``0.5 * (log1p(x) - log1p(-x))``.

        :param x: tensor of values to invert; no clamping is applied here
        :return: ``atanh(x)``
        """
        log_one_plus = th.log1p(x)
        log_one_minus = th.log1p(-x)
        return 0.5 * (log_one_plus - log_one_minus)

    @staticmethod
    def inverse(y: th.Tensor) -> th.Tensor:
        """
        Inverse of ``forward``, with the input clipped for safety.

        :param y: squashed tensor
        :return: ``atanh`` of ``y`` after clamping it to
            ``[-1 + eps, 1 - eps]``, where ``eps`` is the machine epsilon
            of ``y.dtype``
        """
        machine_eps = th.finfo(y.dtype).eps
        # Keep values strictly inside (-1, 1): atanh is infinite at +/-1.
        clipped = th.clamp(y, min=-1.0 + machine_eps, max=1.0 - machine_eps)
        return TanhBijector.atanh(clipped)

    def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
        """
        Squash correction term (from the original SAC implementation).

        :param x: pre-squash tensor
        :return: ``log(1 - tanh(x)^2 + epsilon)``, computed element-wise
        """
        tanh_squared = th.tanh(x).pow(2)
        # 1 - tanh(x)^2 is the derivative of tanh; epsilon keeps the log
        # argument strictly positive when tanh(x)^2 rounds to 1.
        return th.log(1.0 - tanh_squared + self.epsilon)