from typing import Any, Dict, List, Optional, Tuple, Union
from enum import Enum

import gym
import torch as th
from torch import nn
from torch.distributions import Normal, Independent, MultivariateNormal
from math import pi

from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.distributions import sum_independent_dims
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
from stable_baselines3.common.distributions import (
    BernoulliDistribution,
    CategoricalDistribution,
    MultiCategoricalDistribution,
    # StateDependentNoiseDistribution,
)
from stable_baselines3.common.distributions import DiagGaussianDistribution

from ..misc.tensor_ops import fill_triangular
from ..misc.tanhBijector import TanhBijector

# TODO: Integrate and Test what I currently have before adding more complexity
# TODO: Support Squashed Dists (tanh)
# TODO: Contextual Cov
# TODO: - Hybrid
# TODO: Contextual SDE (Scalar + Diag + Full)
# TODO: (SqrtInducedCov (Scalar + Diag + Full))


class Strength(Enum):
    NONE = 0
    SCALAR = 1
    DIAG = 2
    FULL = 3


class ParametrizationType(Enum):
    NONE = 0
    CHOL = 1
    SPHERICAL_CHOL = 2
    # Not (yet?) implemented:
    # GIVENS = 3
    # NNLN_EIGEN = 4


class EnforcePositiveType(Enum):
    # This needs to be implemented in this ugly fashion,
    # because cloudpickle does not like more complex enums.

    NONE = 0
    SOFTPLUS = 1
    ABS = 2
    RELU = 3
    LOG = 4

    def apply(self, x):
        # Dispatch on the enum value: index into the matching positivity-enforcing op.
        return [nn.Identity(), nn.Softplus(beta=1, threshold=20), th.abs, nn.ReLU(inplace=False), th.log][self.value](x)


class ProbSquashingType(Enum):
    NONE = 0
    TANH = 1

    def apply(self, x):
        return [nn.Identity(), th.tanh][self.value](x)

    def apply_inv(self, x):
        return [nn.Identity(), TanhBijector.inverse][self.value](x)


def cast_to_enum(inp, Class):
    if isinstance(inp, Enum):
        return inp
    else:
        return Class[inp]
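
# For illustration: cast_to_enum accepts either an enum member or its name as a string,
# e.g. cast_to_enum('DIAG', Strength) -> Strength.DIAG, while Strength.DIAG is passed through unchanged.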


def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
    allowedEPTs = allowedEPTs or EnforcePositiveType
    allowedParStrength = allowedParStrength or Strength
    allowedCovStrength = allowedCovStrength or Strength
    allowedPTs = allowedPTs or ParametrizationType
    allowedPSTs = allowedPSTs or ProbSquashingType

    for ps in allowedParStrength:
        for cs in allowedCovStrength:
            if ps.value > cs.value:
                continue
            if cs == Strength.NONE:
                yield (ps, cs, EnforcePositiveType.NONE, ParametrizationType.NONE)
            else:
                for ept in allowedEPTs:
                    if cs == Strength.FULL:
                        for pt in allowedPTs:
                            if pt != ParametrizationType.NONE:
                                yield (ps, cs, ept, pt)
                    else:
                        yield (ps, cs, ept, ParametrizationType.NONE)
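
# For illustration (a sketch of how the generator above behaves with the default enums):
#   next(get_legal_setups(allowedCovStrength=[Strength.FULL], allowedPTs=[ParametrizationType.CHOL]))
#   -> (Strength.NONE, Strength.FULL, EnforcePositiveType.NONE, ParametrizationType.CHOL)
# Each yielded tuple is (parameter strength, covariance strength, positivity enforcement, parametrization).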


def make_proba_distribution(
    action_space: gym.spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
) -> SB3_Distribution:
    """
    Return an instance of Distribution for the correct type of action space.

    :param action_space: the input action space
    :param use_sde: Force the use of state-dependent exploration (SDE)
        instead of plain Gaussian exploration
    :param dist_kwargs: Keyword arguments to pass to the probability distribution
    :return: the appropriate Distribution object
    """
    if dist_kwargs is None:
        dist_kwargs = {}

    if isinstance(action_space, gym.spaces.Box):
        assert len(action_space.shape) == 1, "Error: the action space must be a vector"
        # use_sde is only meaningful for the Gaussian distribution; the discrete distributions below do not accept it.
        dist_kwargs['use_sde'] = use_sde
        return UniversalGaussianDistribution(get_action_dim(action_space), **dist_kwargs)
    elif isinstance(action_space, gym.spaces.Discrete):
        return CategoricalDistribution(action_space.n, **dist_kwargs)
    elif isinstance(action_space, gym.spaces.MultiDiscrete):
        return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
    elif isinstance(action_space, gym.spaces.MultiBinary):
        return BernoulliDistribution(action_space.n, **dist_kwargs)
    else:
        raise NotImplementedError(
            "Error: probability distribution not implemented for action space"
            f" of type {type(action_space)}."
            " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary."
        )
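
# For illustration (hedged sketch; the Box bounds are arbitrary here):
#   space = gym.spaces.Box(low=-1.0, high=1.0, shape=(3,))
#   dist = make_proba_distribution(space, dist_kwargs={'cov_strength': Strength.FULL,
#                                                      'parameterization_type': ParametrizationType.CHOL})
#   # -> UniversalGaussianDistribution over a 3-dimensional action space with a full covariance matrix.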


class UniversalGaussianDistribution(SB3_Distribution):
    """
    Gaussian distribution with configurable covariance matrix shape and optional
    contextual parametrization mechanism, for continuous actions.

    :param action_dim: Dimension of the action space.
    """

    def __init__(self, action_dim: int, use_sde: bool = False,
                 neural_strength: Strength = Strength.DIAG,
                 cov_strength: Strength = Strength.DIAG,
                 parameterization_type: ParametrizationType = ParametrizationType.NONE,
                 enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS,
                 prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE,
                 epsilon=1e-3, sde_learn_features=False):
        super(UniversalGaussianDistribution, self).__init__()
        self.action_dim = action_dim
        self.par_strength = cast_to_enum(neural_strength, Strength)
        self.cov_strength = cast_to_enum(cov_strength, Strength)
        self.par_type = cast_to_enum(
            parameterization_type, ParametrizationType)
        self.enforce_positive_type = cast_to_enum(
            enforce_positive_type, EnforcePositiveType)
        self.prob_squashing_type = cast_to_enum(
            prob_squashing_type, ProbSquashingType)

        self.epsilon = epsilon

        self.distribution = None
        self.gaussian_actions = None

        self.use_sde = use_sde
        self.learn_features = sde_learn_features

        assert (self.par_type != ParametrizationType.NONE) == (
            self.cov_strength == Strength.FULL), 'You should set a ParametrizationType iff the cov-strength is FULL.'

        if self.par_type == ParametrizationType.SPHERICAL_CHOL and self.enforce_positive_type == EnforcePositiveType.NONE:
            raise Exception(
                'You need to specify an enforce_positive_type for the spherical Cholesky parametrization.')

    def new_dist_like_me(self, mean: th.Tensor, chol: th.Tensor):
        p = self.distribution
        if isinstance(p, Independent):
            if p.stddev.shape != chol.shape:
                chol = th.diagonal(chol, dim1=1, dim2=2)
            np = Independent(Normal(mean, chol), 1)
        elif isinstance(p, MultivariateNormal):
            np = MultivariateNormal(mean, scale_tril=chol)
        else:
            raise Exception('Unsupported underlying torch distribution.')
        new = UniversalGaussianDistribution(self.action_dim, use_sde=self.use_sde, neural_strength=self.par_strength, cov_strength=self.cov_strength,
                                            parameterization_type=self.par_type, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type, epsilon=self.epsilon, sde_learn_features=self.learn_features)
        new.distribution = np

        return new

    def new_dist_like_me_from_sqrt(self, mean: th.Tensor, cov_sqrt: th.Tensor):
        chol = self._sqrt_to_chol(cov_sqrt)

        new = self.new_dist_like_me(mean, chol)

        new.cov_sqrt = cov_sqrt
        new.distribution.cov_sqrt = cov_sqrt

        return new

    def proba_distribution_net(self, latent_dim: int, latent_sde_dim: int, std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
        """
        Create the layers and parameters that represent the distribution:
        one output will be the mean of the Gaussian; the other parametrizes
        the (Cholesky factor of the) covariance.

        :param latent_dim: Dimension of the last layer of the policy (before the action layer)
        :param latent_sde_dim: Dimension of the latent features used to generate the state-dependent exploration noise
        :param std_init: Initial value for the standard deviation
        :return: Two nn.Modules (mean, chol). chol can output a vector if the full chol would be diagonal.
        """

        assert std_init >= 0.0, "std cannot be initialized to a negative value."

        # TODO: Implement SDE
        self.latent_sde_dim = latent_sde_dim

        mean_actions = nn.Linear(latent_dim, self.action_dim)
        chol = CholNet(latent_dim, self.action_dim, std_init, self.par_strength,
                       self.cov_strength, self.par_type, self.enforce_positive_type, self.prob_squashing_type, self.epsilon)

        if self.use_sde:
            self.sample_weights(self.action_dim)

        return mean_actions, chol
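
    # Typical wiring from a policy (illustrative sketch; `latent` is assumed to be the policy's
    # latent features of shape (batch, latent_dim)):
    #   mean_net, chol_net = dist.proba_distribution_net(latent_dim=64, latent_sde_dim=64)
    #   dist.proba_distribution(mean_net(latent), chol_net(latent), latent)
    #   actions = dist.sample()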

    def _sqrt_to_chol(self, cov_sqrt):
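        # Build a valid Cholesky factor from a (possibly non-triangular) matrix square root:
        # cov = cov_sqrt^T @ cov_sqrt + epsilon * I, then chol = cholesky(cov).
        # A 1-D or 2-D input is treated as a (batch of) diagonal square root vector(s)
        # and the result is returned as a vector again.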
        vec = False
        nobatch = False
        if len(cov_sqrt.shape) <= 2:
            vec = True
            if len(cov_sqrt.shape) == 1:
                nobatch = True

        if vec:
            cov_sqrt = th.diag_embed(cov_sqrt)

        if nobatch:
            cov = th.mm(cov_sqrt.mT, cov_sqrt)
            cov += th.eye(cov.shape[-1]) * self.epsilon
        else:
            cov = th.bmm(cov_sqrt.mT, cov_sqrt)
            cov += th.eye(cov.shape[-1]).expand(cov.shape) * self.epsilon

        chol = th.linalg.cholesky(cov)

        if vec:
            chol = th.diagonal(chol, dim1=-2, dim2=-1)

        return chol

    def proba_distribution_from_sqrt(self, mean_actions: th.Tensor, cov_sqrt: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution":
        """
        Create the distribution given its parameters (mean, cov_sqrt).

        :param mean_actions: mean of the Gaussian
        :param cov_sqrt: (matrix) square root of the covariance
        :return: self
        """
        self.cov_sqrt = cov_sqrt
        chol = self._sqrt_to_chol(cov_sqrt)
        self.proba_distribution(mean_actions, chol, latent_pi)
        self.distribution.cov_sqrt = cov_sqrt
        return self

    def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_sde: th.Tensor) -> "UniversalGaussianDistribution":
        """
        Create the distribution given its parameters (mean, chol).

        :param mean_actions: mean of the Gaussian
        :param chol: Cholesky factor of the covariance (a vector if the covariance is diagonal)
        :return: self
        """
        if self.use_sde:
            self._latent_sde = latent_sde if self.learn_features else latent_sde.detach()
            # TODO: Change variance of dist to include sde-spread

        if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]:
            self.distribution = Independent(Normal(mean_actions, chol), 1)
        elif self.cov_strength in [Strength.FULL]:
            self.distribution = MultivariateNormal(
                mean_actions, scale_tril=chol)
        if self.distribution is None:
            raise Exception('Unable to create torch distribution')
        return self

    def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
        """
        Get the log probabilities of actions according to the distribution.
        Note that you must first call the ``proba_distribution()`` method.

        :param actions: actions as seen by the environment (possibly squashed)
        :param gaussian_actions: the corresponding pre-squash Gaussian actions, if already known
        :return: the log probability of the actions
        """
        if self.prob_squashing_type == ProbSquashingType.NONE:
            log_prob = self.distribution.log_prob(actions)
            return log_prob

        if gaussian_actions is None:
            # It will be clipped to avoid NaN when inverting tanh
            gaussian_actions = self.prob_squashing_type.apply_inv(actions)

        log_prob = self.distribution.log_prob(gaussian_actions)
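
        # Change of variables for the tanh squashing: if a = tanh(u) with u ~ N(mean, cov), then
        # log pi(a) = log p(u) - sum_i log(1 - a_i^2); epsilon guards against log(0).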
        if self.prob_squashing_type == ProbSquashingType.TANH:
            log_prob -= th.sum(th.log(1 - actions**2 + self.epsilon), dim=1)
            return log_prob

        raise Exception('Unsupported ProbSquashingType.')

    def entropy(self) -> th.Tensor:
        # TODO: This will return incorrect results when using prob-squashing
        return self.distribution.entropy()

    def sample(self) -> th.Tensor:
        if self.use_sde:
            return self._sample_sde()
        else:
            return self._sample_normal()

    def _sample_normal(self) -> th.Tensor:
        # Reparametrization trick to pass gradients
        sample = self.distribution.rsample()
        self.gaussian_actions = sample
        return self.prob_squashing_type.apply(sample)

    def _sample_sde(self) -> th.Tensor:
        noise = self.get_noise(self._latent_sde)
        actions = self.distribution.mean + noise
        self.gaussian_actions = actions
        return self.prob_squashing_type.apply(actions)

    def mode(self) -> th.Tensor:
        mode = self.distribution.mean
        self.gaussian_actions = mode
        return self.prob_squashing_type.apply(mode)

    def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_pi=None) -> th.Tensor:
        # Update the proba distribution
        self.proba_distribution(mean_actions, log_std, latent_sde=latent_pi)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        Compute the log probability of taking an action
        given the distribution parameters.

        :param mean_actions: mean of the Gaussian
        :param log_std: parameters of the covariance as produced by CholNet (a Cholesky factor, despite the name)
        :return: sampled actions and their log probability
        """
        actions = self.actions_from_params(mean_actions, log_std)
        log_prob = self.log_prob(actions, self.gaussian_actions)
        return actions, log_prob

    def sample_weights(self, batch_size=1):
        num_dims = (self.latent_sde_dim, self.action_dim)
        self.weights_dist = Normal(th.zeros(num_dims), th.ones(num_dims))
        # Reparametrization trick to pass gradients
        self.exploration_mat = self.weights_dist.rsample()
        # Pre-compute matrices in case of parallel exploration
        self.exploration_matrices = self.weights_dist.rsample((batch_size,))
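
    # Note: each exploration matrix has shape (latent_sde_dim, action_dim); get_noise() below maps the
    # latent features through these fixed matrices and scales the result by the current Cholesky factor,
    # giving state-dependent exploration noise that only changes when sample_weights() is called again.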

    def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
        latent_sde = latent_sde if self.learn_features else latent_sde.detach()
        # TODO: Good idea?
        latent_sde = th.nn.functional.normalize(latent_sde, dim=-1)
        # Default case: only one exploration matrix
        if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
            chol = th.diag_embed(self.distribution.stddev)
            return (th.mm(latent_sde, self.exploration_mat) @ chol)[0]
        p = self.distribution
        if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent):
            chol = th.diag_embed(self.distribution.stddev)
        elif isinstance(p, th.distributions.MultivariateNormal):
            chol = p.scale_tril

        # Use batch matrix multiplication for efficient computation
        # (batch_size, n_features) -> (batch_size, 1, n_features)
        latent_sde = latent_sde.unsqueeze(dim=1)
        # (batch_size, 1, n_actions)
        noise = th.bmm(th.bmm(latent_sde, self.exploration_matrices), chol)
        return noise.squeeze(dim=1)


class CholNet(nn.Module):
    def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: Strength, cov_strength: Strength, par_type: ParametrizationType, enforce_positive_type: EnforcePositiveType, prob_squashing_type: ProbSquashingType, epsilon):
        super().__init__()
        self.latent_dim = latent_dim
        self.action_dim = action_dim

        self.par_strength = par_strength
        self.cov_strength = cov_strength
        self.par_type = par_type
        self.enforce_positive_type = enforce_positive_type
        self.prob_squashing_type = prob_squashing_type

        self.epsilon = epsilon

        self._flat_chol_len = action_dim * (action_dim + 1) // 2
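        # A lower-triangular n x n matrix has n * (n + 1) / 2 free entries,
        # e.g. action_dim = 3 gives 6 flat parameters.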

        # Yes, this is ugly.
        # But I don't know how this mess could be elegantly abstracted away...

        if self.par_strength == Strength.NONE:
            if self.cov_strength == Strength.NONE:
                self.chol = th.ones(self.action_dim) * std_init
            elif self.cov_strength == Strength.SCALAR:
                self.param = nn.Parameter(
                    th.Tensor([std_init]), requires_grad=True)
            elif self.cov_strength == Strength.DIAG:
                self.params = nn.Parameter(
                    th.ones(self.action_dim) * std_init, requires_grad=True)
            elif self.cov_strength == Strength.FULL:
                # TODO: Init Off-axis differently?
                self.params = nn.Parameter(
                    th.ones(self._full_params_len) * std_init, requires_grad=True)
        elif self.par_strength == self.cov_strength:
            if self.par_strength == Strength.SCALAR:
                self.std = nn.Linear(latent_dim, 1)
            elif self.par_strength == Strength.DIAG:
                self.diag_chol = nn.Linear(latent_dim, self.action_dim)
            elif self.par_strength == Strength.FULL:
                self.params = nn.Linear(latent_dim, self._full_params_len)
        elif self.par_strength.value > self.cov_strength.value:
            raise Exception(
                'The parameterization cannot be stronger than the actual covariance.')
        else:
            if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
                self.factor = nn.Linear(latent_dim, 1)
                self.param = nn.Parameter(
                    th.ones(self.action_dim), requires_grad=True)
            elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
                if self.enforce_positive_type == EnforcePositiveType.NONE:
                    raise Exception(
                        'For Hybrid[Diag=>Full], enforce_positive_type must not be NONE; otherwise the required SPD constraint on the covariance cannot be ensured.')
                self.stds = nn.Linear(latent_dim, self.action_dim)
                self.padder = th.nn.ZeroPad2d((0, 1, 1, 0))
                # TODO: Init Non-zero?
                self.params = nn.Parameter(
                    th.ones(self._full_params_len - self.action_dim) * 0, requires_grad=True)
            elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
                self.factor = nn.Linear(latent_dim, 1)
                # TODO: Init Off-axis differently?
                self.params = nn.Parameter(
                    th.ones(self._full_params_len) * std_init, requires_grad=True)
            else:
                raise Exception('Unreachable: all valid Strength combinations are handled above.')

    def forward(self, x: th.Tensor) -> th.Tensor:
        # Ugly mess pt.2:
        if self.par_strength == Strength.NONE:
            if self.cov_strength == Strength.NONE:
                return self.chol
            elif self.cov_strength == Strength.SCALAR:
                return self._ensure_positive_func(
                    th.ones(self.action_dim) * self.param[0])
            elif self.cov_strength == Strength.DIAG:
                return self._ensure_positive_func(self.params)
            elif self.cov_strength == Strength.FULL:
                return self._parameterize_full(self.params)
        elif self.par_strength == self.cov_strength:
            if self.par_strength == Strength.SCALAR:
                std = self.std(x)
                diag_chol = th.ones(self.action_dim) * std
                return self._ensure_positive_func(diag_chol)
            elif self.par_strength == Strength.DIAG:
                diag_chol = self.diag_chol(x)
                return self._ensure_positive_func(diag_chol)
            elif self.par_strength == Strength.FULL:
                params = self.params(x)
                return self._parameterize_full(params)
        else:
            if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
                factor = self.factor(x)[0]
                diag_chol = self._ensure_positive_func(
                    self.param * factor)
                return diag_chol
            elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
                # TODO: Maybe possible to improve speed and stability by converting from
                # Pearson correlation + stds to the covariance directly in Cholesky form.
                stds = self._ensure_positive_func(self.stds(x))
                smol = self._parameterize_full(self.params)
                big = self.padder(smol)
                pearson_cor_chol = big + th.eye(stds.shape[-1])
                pearson_cor = (pearson_cor_chol.T @
                               pearson_cor_chol)
                if len(stds.shape) > 1:
                    # batched operation, we need to expand
                    pearson_cor = pearson_cor.expand(
                        (stds.shape[0],) + pearson_cor.shape)
                    stds = stds.unsqueeze(2)
                cov = stds.mT * pearson_cor * stds
                chol = th.linalg.cholesky(cov)
                return chol
            elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
                # TODO: Maybe possible to improve speed and stability by multiplying with the factor in Cholesky form.
                factor = self._ensure_positive_func(self.factor(x))
                par_chol = self._parameterize_full(self.params)
                cov = (par_chol.T @ par_chol)
                if len(factor) > 1:
                    factor = factor.unsqueeze(2)
                cov = cov * factor
                chol = th.linalg.cholesky(cov)
                return chol
        raise Exception('Unreachable combination of par_strength and cov_strength.')

    @property
    def _full_params_len(self):
        if self.par_type == ParametrizationType.CHOL:
            return self._flat_chol_len
        elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
            return self._flat_chol_len
        raise Exception('Unknown ParametrizationType.')

    def _parameterize_full(self, params):
        if self.par_type == ParametrizationType.CHOL:
            return self._chol_from_flat(params)
        elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
            return self._chol_from_flat_sphe_chol(params)
        raise Exception('Unknown ParametrizationType.')

    def _chol_from_flat(self, flat_chol):
        chol = fill_triangular(flat_chol)
        return self._ensure_diagonal_positive(chol)

    def _chol_from_flat_sphe_chol(self, flat_sphe_chol):
        pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol)
        sphe_chol = fill_triangular(pos_flat_sphe_chol)
        chol = self._chol_from_sphe_chol(sphe_chol)
        return chol

    def _chol_from_sphe_chol(self, sphe_chol):
        # TODO: Make more efficient
        # Note:
        # We must ensure:
        # S[i,1] > 0         for i = 1..n
        # S[i,j] in (0, pi)  for i = 2..n, j = 2..i
        # We already ensure S > 0 in _chol_from_flat_sphe_chol
        # We ensure < pi by applying tanh*pi to all applicable elements
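        # Worked 2-D example (illustrative): with S = [[s11, 0], [s21, s22]] and theta = tanh(s22) * pi,
        # the loop below yields L = [[s11, 0], [s21 * cos(theta), s21 * sin(theta)]],
        # i.e. the first column of S holds the row norms and the remaining entries act as angles.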
        batch = (len(sphe_chol.shape) == 3)
        batch_size = sphe_chol.shape[0]
        S = sphe_chol
        n = sphe_chol.shape[-1]
        L = th.zeros_like(sphe_chol)
        for i in range(n):
            t = th.Tensor([1])[0]
            if batch:
                t = t.expand((batch_size, 1))
            for j in range(i + 1):
                maybe_cos = th.Tensor([1])[0]
                if batch:
                    maybe_cos = maybe_cos.expand((batch_size, 1))
                if i != j and j < n - 1 and i < n:
                    if batch:
                        maybe_cos = th.cos(th.tanh(S[:, i, j + 1]) * pi)
                    else:
                        maybe_cos = th.cos(th.tanh(S[i, j + 1]) * pi)
                if batch:
                    L[:, i, j] = (S[:, i, 0] * t.T) * maybe_cos.T
                else:
                    L[i, j] = S[i, 0] * t * maybe_cos
                if j <= i and j < n - 1 and i < n:
                    if batch:
                        tc = t.clone()
                        t = (tc.T * th.sin(th.tanh(S[:, i, j + 1]) * pi)).T
                    else:
                        t *= th.sin(th.tanh(S[i, j + 1]) * pi)
        return L

    def _ensure_positive_func(self, x):
        return self.enforce_positive_type.apply(x) + self.epsilon

    def _ensure_diagonal_positive(self, chol):
        if len(chol.shape) == 1:
            # If our chol is a vector (representing a diagonal chol)
            return self._ensure_positive_func(chol)
        return chol.tril(-1) + self._ensure_positive_func(chol.diagonal(dim1=-2,
                                                                        dim2=-1)).diag_embed() + chol.triu(1)

    def string(self):
        # TODO
        return '<CholNet />'


AnyDistribution = Union[SB3_Distribution, UniversalGaussianDistribution]