ProbSquashing implemented (tanh)
This commit is contained in:
parent
05dad44b6e
commit
199ce0c8cb
@ -20,6 +20,7 @@ from stable_baselines3.common.distributions import (
|
||||
from stable_baselines3.common.distributions import DiagGaussianDistribution
|
||||
|
||||
from ..misc.tensor_ops import fill_triangular
|
||||
from ..misc.tanhBijector import TanhBijector
|
||||
|
||||
# TODO: Integrate and Test what I currently have before adding more complexity
|
||||
# TODO: Support Squashed Dists (tanh)
|
||||
@ -67,6 +68,9 @@ class ProbSquashingType(Enum):
|
||||
def apply(self, x):
    """
    Run ``x`` through the squashing function selected by this member.

    ``self.value`` is used as an index into ``[identity, tanh]``.

    :param x: input handed to the selected function.
    :return: output of the selected function.
    """
    squashers = [nn.Identity(), th.tanh]
    chosen = squashers[self.value]
    return chosen(x)
|
||||
|
||||
def apply_inv(self, x):
    """
    Run ``x`` through the inverse of the squashing function selected by this member.

    ``self.value`` is used as an index into ``[identity, TanhBijector.inverse]``.

    :param x: input handed to the selected inverse function.
    :return: output of the selected inverse function.
    """
    inverses = [nn.Identity(), TanhBijector.inverse]
    chosen = inverses[self.value]
    return chosen(x)
|
||||
|
||||
|
||||
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
|
||||
allowedEPTs = allowedEPTs or EnforcePositiveType
|
||||
@ -130,7 +134,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
:param action_dim: Dimension of the action space.
|
||||
"""
|
||||
|
||||
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE):
|
||||
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-6):
|
||||
super(UniversalGaussianDistribution, self).__init__()
|
||||
self.action_dim = action_dim
|
||||
self.par_strength = neural_strength
|
||||
@ -139,7 +143,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
self.enforce_positive_type = enforce_positive_type
|
||||
self.prob_squashing_type = prob_squashing_type
|
||||
|
||||
self.epsilon = epsilon
|
||||
|
||||
self.distribution = None
|
||||
self.gaussian_actions = None
|
||||
|
||||
if self.prob_squashing_type != ProbSquashingType.NONE:
|
||||
raise Exception('ProbSquasing is not yet implmenented!')
|
||||
@ -209,7 +216,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
raise Exception('Unable to create torch distribution')
|
||||
return self
|
||||
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
|
||||
"""
|
||||
Get the log probabilities of actions according to the distribution.
|
||||
Note that you must first call the ``proba_distribution()`` method.
|
||||
@ -217,18 +224,37 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
:param actions:
|
||||
:return:
|
||||
"""
|
||||
if self.prob_squashing_type == ProbSquashingType.NONE:
|
||||
log_prob = self.distribution.log_prob(actions)
|
||||
return log_prob
|
||||
|
||||
if gaussian_actions is None:
|
||||
# It will be clipped to avoid NaN when inversing tanh
|
||||
gaussian_actions = self.prob_squashing_type.apply_inv(actions)
|
||||
|
||||
log_prob = self.distribution.log_prob(gaussian_actions)
|
||||
|
||||
if self.prob_squashing_type == ProbSquashingType.TANH:
|
||||
log_prob -= th.sum(th.log(1 - actions **
|
||||
2 + self.epsilon), dim=1)
|
||||
return log_prob
|
||||
|
||||
raise Exception()
|
||||
|
||||
def entropy(self) -> th.Tensor:
    """
    Return the entropy reported by the wrapped distribution.

    :return: result of ``self.distribution.entropy()``.
    """
    # TODO: This will return incorrect results when using prob-squashing
    base_dist = self.distribution
    return base_dist.entropy()
|
||||
|
||||
def sample(self) -> th.Tensor:
    """
    Draw a reparameterized sample and pass it through the squashing function.

    The unsquashed sample is cached in ``self.gaussian_actions`` before
    squashing, so it can later be handed to ``log_prob`` alongside the
    squashed action.

    :return: ``self.prob_squashing_type.apply`` applied to the raw sample.
    """
    # Reparametrization trick to pass gradients
    sample = self.distribution.rsample()
    # Must run before returning: an early return of the raw sample would skip
    # both the caching and the squashing.
    self.gaussian_actions = sample
    return self.prob_squashing_type.apply(sample)
|
||||
|
||||
def mode(self) -> th.Tensor:
    """
    Return the distribution mean passed through the squashing function.

    The unsquashed mean is cached in ``self.gaussian_actions`` before
    squashing, so it can later be handed to ``log_prob`` alongside the
    squashed action.

    :return: ``self.prob_squashing_type.apply`` applied to the mean.
    """
    mode = self.distribution.mean
    # Must run before returning: an early return of the raw mean would skip
    # both the caching and the squashing.
    self.gaussian_actions = mode
    return self.prob_squashing_type.apply(mode)
|
||||
|
||||
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_pi=None) -> th.Tensor:
|
||||
# Update the proba distribution
|
||||
@ -245,7 +271,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
:return:
|
||||
"""
|
||||
actions = self.actions_from_params(mean_actions, log_std)
|
||||
log_prob = self.log_prob(actions)
|
||||
log_prob = self.log_prob(actions, self.gaussian_actions)
|
||||
return actions, log_prob
|
||||
|
||||
|
||||
|
44
metastable_baselines/misc/tanhBijector.py
Normal file
44
metastable_baselines/misc/tanhBijector.py
Normal file
@ -0,0 +1,44 @@
|
||||
import torch as th
|
||||
|
||||
|
||||
class TanhBijector:
    """
    Tanh squashing bijector for probability distributions (taken from SB3).

    Provides the forward squashing (tanh), its clipped inverse, and the
    log-probability correction term for a squashed sample.
    TODO: use Pyro instead (https://pyro.ai/)

    :param epsilon: small value to avoid NaN due to numerical imprecision.
    """

    def __init__(self, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon

    @staticmethod
    def forward(x: th.Tensor) -> th.Tensor:
        """Squash ``x`` with tanh."""
        squashed = th.tanh(x)
        return squashed

    @staticmethod
    def atanh(x: th.Tensor) -> th.Tensor:
        """
        Inverse of tanh, without any clipping of the input.

        Taken from Pyro: https://github.com/pyro-ppl/pyro
        Computes 0.5 * log((1 + x) / (1 - x)) as a difference of two log1p terms.
        """
        log_one_plus = th.log1p(x)
        log_one_minus = th.log1p(-x)
        return 0.5 * (log_one_plus - log_one_minus)

    @staticmethod
    def inverse(y: th.Tensor) -> th.Tensor:
        """
        Inverse tanh with the input clipped just inside (-1, 1).

        :param y: squashed values.
        :return: ``atanh`` of the clipped values.
        """
        # Clip by the dtype's machine epsilon so that atanh never sees +/-1
        margin = th.finfo(y.dtype).eps
        lower = -1.0 + margin
        upper = 1.0 - margin
        clipped = th.clamp(y, min=lower, max=upper)
        return TanhBijector.atanh(clipped)

    def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
        """
        Squash correction term (from the original SAC implementation).

        :param x: unsquashed values.
        :return: ``log(1 - tanh(x)^2 + epsilon)``, element-wise.
        """
        squashed = th.tanh(x)
        return th.log(1.0 - squashed ** 2 + self.epsilon)
|
Loading…
Reference in New Issue
Block a user