import torch as th

class TanhBijector:
    """
    Bijective transformation of a probability distribution
    using a squashing function (tanh).

    Adapted from Stable-Baselines3 (SB3).
    TODO: use Pyro instead.

    :param epsilon: small value to avoid NaN due to numerical imprecision.
    """

    def __init__(self, epsilon: float = 1e-6):
        self.epsilon = epsilon

    @staticmethod
    def forward(x: th.Tensor) -> th.Tensor:
        """
        Squash the input with tanh.

        :param x: tensor to squash.
        :return: ``tanh(x)``, element-wise.
        """
        return th.tanh(x)

    @staticmethod
    def atanh(x: th.Tensor) -> th.Tensor:
        """
        Inverse of tanh.

        Taken from Pyro; mathematically equal to
        ``0.5 * torch.log((1 + x) / (1 - x))``.

        :param x: tensor to invert; no clamping is applied here, see
            :meth:`inverse` for the clamped variant.
        :return: ``atanh(x)``, element-wise.
        """
        # Written with log1p instead of log((1 + x) / (1 - x)).
        return 0.5 * (x.log1p() - (-x).log1p())

    @staticmethod
    def inverse(y: th.Tensor) -> th.Tensor:
        """
        Inverse tanh, with the input clamped away from -1 and 1.

        :param y: squashed tensor (must have a floating-point dtype, since
            the clamp margin is ``th.finfo(y.dtype).eps``).
        :return: ``atanh`` of the clamped input, element-wise.
        """
        eps = th.finfo(y.dtype).eps
        # Clip the action to avoid NaN
        return TanhBijector.atanh(y.clamp(min=-1.0 + eps, max=1.0 - eps))

    def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
        """
        Squash correction (from original SAC implementation).

        :param x: tensor before squashing.
        :return: ``log(1 - tanh(x)^2 + epsilon)``, element-wise.
        """
        # self.epsilon keeps the log argument strictly positive when
        # tanh(x) ** 2 rounds to 1.
        return th.log(1.0 - th.tanh(x) ** 2 + self.epsilon)