45 lines
1.2 KiB
Python
45 lines
1.2 KiB
Python
|
import torch as th
|
||
|
|
||
|
|
||
|
class TanhBijector:
    """
    Bijective transformation of a probability distribution
    using a squashing function (tanh).

    Adapted from Stable-Baselines3 (SB3).

    TODO: use Pyro instead (https://pyro.ai/)

    :param epsilon: small value to avoid NaN due to numerical imprecision.
        It is added inside the logarithm in ``log_prob_correction``.
    """

    def __init__(self, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon

    @staticmethod
    def forward(x: th.Tensor) -> th.Tensor:
        """
        Squash the input with tanh.

        :param x: unsquashed tensor.
        :return: ``tanh(x)``, element-wise.
        """
        return th.tanh(x)

    @staticmethod
    def atanh(x: th.Tensor) -> th.Tensor:
        """
        Inverse of tanh.

        Taken from Pyro: https://github.com/pyro-ppl/pyro
        Mathematically equal to ``0.5 * torch.log((1 + x) / (1 - x))``,
        written with ``log1p`` instead of an explicit division.

        No clipping is done here: inputs equal to -1 or 1 give infinite
        values. Use ``inverse`` for a clipped version.

        :param x: tensor of values to invert.
        :return: ``atanh(x)``, element-wise.
        """
        return 0.5 * (x.log1p() - (-x).log1p())

    @staticmethod
    def inverse(y: th.Tensor) -> th.Tensor:
        """
        Inverse tanh, with the input clipped to avoid infinite/NaN results.

        :param y: squashed tensor; must have a floating-point dtype because
            its machine epsilon is looked up with ``th.finfo``.
        :return: ``atanh`` of ``y`` after clamping ``y`` to
            ``[-1 + eps, 1 - eps]``, where ``eps`` is the machine epsilon
            of ``y.dtype``.
        """
        eps = th.finfo(y.dtype).eps
        # Clip the action to avoid NaN
        return TanhBijector.atanh(y.clamp(min=-1.0 + eps, max=1.0 - eps))

    def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
        """
        Log of the tanh derivative, used as the squash correction term.

        :param x: unsquashed tensor.
        :return: ``log(1 - tanh(x)^2 + epsilon)``, element-wise.
            NOTE(review): presumably subtracted from the base distribution's
            log-probability by the caller -- confirm at the call site.
        """
        # Squash correction (from original SAC implementation)
        # self.epsilon keeps the log argument strictly positive when
        # tanh(x) saturates at +/-1.
        return th.log(1.0 - th.tanh(x) ** 2 + self.epsilon)
|