metastable-baselines/metastable_baselines/misc/tanhBijector.py

45 lines
1.2 KiB
Python

import torch as th
class TanhBijector:
    """
    Bijective transformation of a probability distribution
    using a squashing function (tanh).

    Taken from SB3.
    TODO: use Pyro instead (https://pyro.ai/)

    :param epsilon: small value to avoid NaN due to numerical imprecision.
        Only used by :meth:`log_prob_correction`.
    """

    def __init__(self, epsilon: float = 1e-6) -> None:
        super().__init__()
        self.epsilon = epsilon

    @staticmethod
    def forward(x: th.Tensor) -> th.Tensor:
        """
        Squash ``x`` element-wise with tanh.

        :param x: tensor to squash.
        :return: ``tanh(x)``, same shape as ``x``.
        """
        return th.tanh(x)

    @staticmethod
    def atanh(x: th.Tensor) -> th.Tensor:
        """
        Inverse of tanh, without any clipping of the input.

        Taken from Pyro: https://github.com/pyro-ppl/pyro
        Mathematically equal to ``0.5 * torch.log((1 + x) / (1 - x))``,
        written with ``log1p`` as ``0.5 * (log1p(x) - log1p(-x))``.

        :param x: tensor of tanh outputs.
        :return: element-wise ``atanh(x)``.
        """
        return 0.5 * (x.log1p() - (-x).log1p())

    @staticmethod
    def inverse(y: th.Tensor) -> th.Tensor:
        """
        Inverse tanh with the input clipped away from +/-1.

        :param y: tensor of squashed values; must have a floating-point dtype
            because the clipping margin is ``th.finfo(y.dtype).eps``.
        :return: ``atanh`` of ``y`` clamped to ``[-1 + eps, 1 - eps]``.
        """
        # The margin is the machine epsilon of y's dtype, not ``self.epsilon``
        # (this is a static method and has no access to the instance).
        eps = th.finfo(y.dtype).eps
        # Clip the action to avoid inf/NaN from atanh at or beyond +/-1.
        return TanhBijector.atanh(y.clamp(min=-1.0 + eps, max=1.0 - eps))

    def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
        """
        Squash correction (from original SAC implementation).

        Computes ``log(1 - tanh(x)^2 + epsilon)`` element-wise, i.e. the log
        of the derivative of tanh at ``x`` with ``epsilon`` added inside the log.

        :param x: pre-squash (unbounded) tensor.
        :return: element-wise correction term, same shape as ``x``.
        """
        # ``epsilon`` keeps the argument of log strictly positive when
        # tanh(x) ** 2 rounds to exactly 1 for large |x|.
        return th.log(1.0 - th.tanh(x) ** 2 + self.epsilon)