ProbSquashing implemented (tanh)
This commit is contained in:
parent
05dad44b6e
commit
199ce0c8cb
@ -20,6 +20,7 @@ from stable_baselines3.common.distributions import (
|
|||||||
from stable_baselines3.common.distributions import DiagGaussianDistribution
|
from stable_baselines3.common.distributions import DiagGaussianDistribution
|
||||||
|
|
||||||
from ..misc.tensor_ops import fill_triangular
|
from ..misc.tensor_ops import fill_triangular
|
||||||
|
from ..misc.tanhBijector import TanhBijector
|
||||||
|
|
||||||
# TODO: Integrate and Test what I currently have before adding more complexity
|
# TODO: Integrate and Test what I currently have before adding more complexity
|
||||||
# TODO: Support Squashed Dists (tanh)
|
# TODO: Support Squashed Dists (tanh)
|
||||||
@ -67,6 +68,9 @@ class ProbSquashingType(Enum):
|
|||||||
def apply(self, x):
|
def apply(self, x):
|
||||||
return [nn.Identity(), th.tanh][self.value](x)
|
return [nn.Identity(), th.tanh][self.value](x)
|
||||||
|
|
||||||
|
def apply_inv(self, x):
|
||||||
|
return [nn.Identity(), TanhBijector.inverse][self.value](x)
|
||||||
|
|
||||||
|
|
||||||
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
|
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
|
||||||
allowedEPTs = allowedEPTs or EnforcePositiveType
|
allowedEPTs = allowedEPTs or EnforcePositiveType
|
||||||
@ -130,7 +134,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
:param action_dim: Dimension of the action space.
|
:param action_dim: Dimension of the action space.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE):
|
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-6):
|
||||||
super(UniversalGaussianDistribution, self).__init__()
|
super(UniversalGaussianDistribution, self).__init__()
|
||||||
self.action_dim = action_dim
|
self.action_dim = action_dim
|
||||||
self.par_strength = neural_strength
|
self.par_strength = neural_strength
|
||||||
@ -139,7 +143,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
self.enforce_positive_type = enforce_positive_type
|
self.enforce_positive_type = enforce_positive_type
|
||||||
self.prob_squashing_type = prob_squashing_type
|
self.prob_squashing_type = prob_squashing_type
|
||||||
|
|
||||||
|
self.epsilon = epsilon
|
||||||
|
|
||||||
self.distribution = None
|
self.distribution = None
|
||||||
|
self.gaussian_actions = None
|
||||||
|
|
||||||
if self.prob_squashing_type != ProbSquashingType.NONE:
|
if self.prob_squashing_type != ProbSquashingType.NONE:
|
||||||
raise Exception('ProbSquasing is not yet implmenented!')
|
raise Exception('ProbSquasing is not yet implmenented!')
|
||||||
@ -209,7 +216,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
raise Exception('Unable to create torch distribution')
|
raise Exception('Unable to create torch distribution')
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
|
||||||
"""
|
"""
|
||||||
Get the log probabilities of actions according to the distribution.
|
Get the log probabilities of actions according to the distribution.
|
||||||
Note that you must first call the ``proba_distribution()`` method.
|
Note that you must first call the ``proba_distribution()`` method.
|
||||||
@ -217,18 +224,37 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
:param actions:
|
:param actions:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
log_prob = self.distribution.log_prob(actions)
|
if self.prob_squashing_type == ProbSquashingType.NONE:
|
||||||
return log_prob
|
log_prob = self.distribution.log_prob(actions)
|
||||||
|
return log_prob
|
||||||
|
|
||||||
|
if gaussian_actions is None:
|
||||||
|
# It will be clipped to avoid NaN when inversing tanh
|
||||||
|
gaussian_actions = self.prob_squashing_type.apply_inv(actions)
|
||||||
|
|
||||||
|
log_prob = self.distribution.log_prob(gaussian_actions)
|
||||||
|
|
||||||
|
if self.prob_squashing_type == ProbSquashingType.TANH:
|
||||||
|
log_prob -= th.sum(th.log(1 - actions **
|
||||||
|
2 + self.epsilon), dim=1)
|
||||||
|
return log_prob
|
||||||
|
|
||||||
|
raise Exception()
|
||||||
|
|
||||||
def entropy(self) -> th.Tensor:
|
def entropy(self) -> th.Tensor:
|
||||||
|
# TODO: This will return incorrect results when using prob-squashing
|
||||||
return self.distribution.entropy()
|
return self.distribution.entropy()
|
||||||
|
|
||||||
def sample(self) -> th.Tensor:
|
def sample(self) -> th.Tensor:
|
||||||
# Reparametrization trick to pass gradients
|
# Reparametrization trick to pass gradients
|
||||||
return self.distribution.rsample()
|
sample = self.distribution.rsample()
|
||||||
|
self.gaussian_actions = sample
|
||||||
|
return self.prob_squashing_type.apply(sample)
|
||||||
|
|
||||||
def mode(self) -> th.Tensor:
|
def mode(self) -> th.Tensor:
|
||||||
return self.distribution.mean
|
mode = self.distribution.mean
|
||||||
|
self.gaussian_actions = mode
|
||||||
|
return self.prob_squashing_type.apply(mode)
|
||||||
|
|
||||||
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_pi=None) -> th.Tensor:
|
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_pi=None) -> th.Tensor:
|
||||||
# Update the proba distribution
|
# Update the proba distribution
|
||||||
@ -245,7 +271,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
actions = self.actions_from_params(mean_actions, log_std)
|
actions = self.actions_from_params(mean_actions, log_std)
|
||||||
log_prob = self.log_prob(actions)
|
log_prob = self.log_prob(actions, self.gaussian_actions)
|
||||||
return actions, log_prob
|
return actions, log_prob
|
||||||
|
|
||||||
|
|
||||||
|
44
metastable_baselines/misc/tanhBijector.py
Normal file
44
metastable_baselines/misc/tanhBijector.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
import torch as th
|
||||||
|
|
||||||
|
|
||||||
|
class TanhBijector:
|
||||||
|
"""
|
||||||
|
Stolen from SB3
|
||||||
|
|
||||||
|
Bijective transformation of a probability distribution
|
||||||
|
using a squashing function (tanh)
|
||||||
|
TODO: use Pyro instead (https://pyro.ai/)
|
||||||
|
:param epsilon: small value to avoid NaN due to numerical imprecision.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, epsilon: float = 1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.epsilon = epsilon
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(x: th.Tensor) -> th.Tensor:
|
||||||
|
return th.tanh(x)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def atanh(x: th.Tensor) -> th.Tensor:
|
||||||
|
"""
|
||||||
|
Inverse of Tanh
|
||||||
|
Taken from Pyro: https://github.com/pyro-ppl/pyro
|
||||||
|
0.5 * torch.log((1 + x ) / (1 - x))
|
||||||
|
"""
|
||||||
|
return 0.5 * (x.log1p() - (-x).log1p())
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def inverse(y: th.Tensor) -> th.Tensor:
|
||||||
|
"""
|
||||||
|
Inverse tanh.
|
||||||
|
:param y:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
eps = th.finfo(y.dtype).eps
|
||||||
|
# Clip the action to avoid NaN
|
||||||
|
return TanhBijector.atanh(y.clamp(min=-1.0 + eps, max=1.0 - eps))
|
||||||
|
|
||||||
|
def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
|
||||||
|
# Squash correction (from original SAC implementation)
|
||||||
|
return th.log(1.0 - th.tanh(x) ** 2 + self.epsilon)
|
Loading…
Reference in New Issue
Block a user