Fixes + spherical_chol

parent e4440428f8
commit 41e4170b2f
@@ -10,6 +10,8 @@ such Third Party IP, are set forth below.
 Overview
 --------------------------------------------------------------------------

+# TODO: Tensorflow-Probability
+
 # TODO: TrustRegionLayers

 boschresearch/trust-region-layers
@@ -4,6 +4,7 @@ from enum import Enum
 import torch as th
 from torch import nn
 from torch.distributions import Normal, MultivariateNormal
+from math import pi

 from stable_baselines3.common.preprocessing import get_action_dim

@@ -13,6 +14,7 @@ from stable_baselines3.common.distributions import DiagGaussianDistribution

 from ..misc.fakeModule import FakeModule
 from ..misc.distTools import new_dist_like
+from ..misc.tensor_ops import fill_triangular

 # TODO: Integrate and Test what I currently have before adding more complexity
 # TODO: Support Squashed Dists (tanh)
@@ -30,9 +32,8 @@ class Strength(Enum):


 class ParametrizationType(Enum):
-    # Currently only Chol is implemented
     CHOL = 1
-    #SPHERICAL_CHOL = 2
+    SPHERICAL_CHOL = 2
     #GIVENS = 3

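Background on the newly enabled option (editorial note, not part of the commit): a spherical Cholesky parameterization expresses each row of the lower-triangular factor L in spherical coordinates, one radius plus one angle per preceding column. A sketch of the standard formula, in the 0-based convention the `_chol_from_sphe_chol` hunk below uses (radius r_i = S[i, 0], angles theta_{i,k} = S[i, k]):

```latex
L_{i,0} = r_i \cos(\theta_{i,1}), \qquad
L_{i,j} = r_i \Big( \prod_{k=1}^{j} \sin(\theta_{i,k}) \Big) \cos(\theta_{i,j+1}) \quad (0 < j < i), \qquad
L_{i,i} = r_i \prod_{k=1}^{i} \sin(\theta_{i,k}).
```

With r_i > 0 and all angles in (0, pi), every diagonal entry of L is positive, so L is a valid Cholesky factor, and each row of L has Euclidean norm r_i.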
@@ -77,6 +78,9 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None):
+        if ps == Strength.SCALAR and cs == Strength.FULL:
+            # TODO: Maybe allow?
+            continue
         if ps == Strength.DIAG and cs == Strength.FULL:
             # TODO: Implement
             continue
         if ps == Strength.NONE:
             yield (ps, cs, None, None)
         else:
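As a quick aid for reading the filter above, here is a self-contained sketch (hypothetical, not the project's actual generator; it assumes `Strength` is an ordered enum, which the `par_strength > cov_strength` comparison in a later hunk also requires) of which `(ps, cs)` pairs survive:

```python
from enum import IntEnum
from itertools import product


class Strength(IntEnum):
    # Ordered so that "stronger" setups compare greater.
    NONE = 0
    SCALAR = 1
    DIAG = 2
    FULL = 3


def legal_strength_pairs():
    for ps, cs in product(Strength, Strength):
        if ps > cs:
            # The parameterization can not be stronger than the covariance.
            continue
        if ps == Strength.SCALAR and cs == Strength.FULL:
            continue  # TODO in the diff: "Maybe allow?"
        if ps == Strength.DIAG and cs == Strength.FULL:
            continue  # TODO in the diff: "Implement"
        yield ps, cs


print(list(legal_strength_pairs()))
# [(NONE, NONE), (NONE, SCALAR), (NONE, DIAG), (NONE, FULL),
#  (SCALAR, SCALAR), (SCALAR, DIAG), (DIAG, DIAG), (FULL, FULL)]
```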
@@ -95,70 +99,72 @@ class UniversalGaussianDistribution(SB3_Distribution):
     :param action_dim: Dimension of the action space.
     """

-    def __init__(self, action_dim: int):
+    def __init__(self, action_dim: int, neural_strength=Strength.DIAG, cov_strength=Strength.DIAG, parameterization_type=ParametrizationType.CHOL, enforce_positive_type=EnforcePositiveType.ABS, prob_squashing_type=ProbSquashingType.TANH):
         super(UniversalGaussianDistribution, self).__init__()
-        self.par_strength = Strength.DIAG
-        self.cov_strength = Strength.DIAG
-        self.par_type = ParametrizationType.CHOL
-        self.enforce_positive_type = EnforcePositiveType.LOG
-        self.prob_squashing_type = ProbSquashingType.TANH
+        self.action_dim = action_dim
+        self.par_strength = neural_strength
+        self.cov_strength = cov_strength
+        self.par_type = parameterization_type
+        self.enforce_positive_type = enforce_positive_type
+        self.prob_squashing_type = prob_squashing_type

         self.distribution = None

+        self._flat_chol_len = action_dim * (action_dim + 1) // 2
+
     def new_dist_like_me(self, mean, pseudo_chol):
         p = self.distribution
         np = new_dist_like(p, mean, pseudo_chol)
-        new = UniversalGaussianDistribution(self.action_dim)
-        new.par_strength = self.par_strength
-        new.cov_strength = self.cov_strength
-        new.par_type = self.par_type
-        new.enforce_positive_type = self.enforce_positive_type
-        new.prob_squashing_type = self.prob_squashing_type
+        new = UniversalGaussianDistribution(self.action_dim, neural_strength=self.par_strength, cov_strength=self.cov_strength,
+                                            parameterization_type=self.par_type, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type)
         new.distribution = np

         return new

-    def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
+    def proba_distribution_net(self, latent_dim: int, std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
         """
         Create the layers and parameter that represent the distribution:
         one output will be the mean of the Gaussian, the other parameter will be the
-        standard deviation (log std in fact to allow negative values)
+        standard deviation

         :param latent_dim: Dimension of the last layer of the policy (before the action layer)
-        :param log_std_init: Initial value for the log standard deviation
-        :return: We return two nn.Modules (mean, chol).
+        :param std_init: Initial value for the standard deviation
+        :return: We return two nn.Modules (mean, chol). chol can be a vector if the full chol would be a diagonal.
         """

-        # TODO: Allow chol to be vector when only diagonal.
+        assert std_init >= 0.0, "std can not be initialized to a negative value."

         mean_actions = nn.Linear(latent_dim, self.action_dim)

         if self.par_strength == Strength.NONE:
             if self.cov_strength == Strength.NONE:
-                pseudo_cov_par = th.ones(self.action_dim) * log_std_init
+                pseudo_cov_par = th.ones(self.action_dim) * std_init
             elif self.cov_strength == Strength.SCALAR:
                 pseudo_cov_par = th.ones(self.action_dim) * \
-                    nn.Parameter(log_std_init, requires_grad=True)
+                    nn.Parameter(th.tensor(std_init), requires_grad=True)
+                pseudo_cov_par = self._ensure_positive_func(pseudo_cov_par)
             elif self.cov_strength == Strength.DIAG:
                 pseudo_cov_par = nn.Parameter(
-                    th.ones(self.action_dim) * log_std_init, requires_grad=True)
+                    th.ones(self.action_dim) * std_init, requires_grad=True)
+                pseudo_cov_par = self._ensure_positive_func(pseudo_cov_par)
             elif self.cov_strength == Strength.FULL:
-                # TODO: This won't work, need to ensure SPD!
-                # TODO: Off-axis init?
-                pseudo_cov_par = nn.Parameter(
-                    th.diag_embed(th.ones(self.action_dim) * log_std_init), requires_grad=True)
+                # TODO: Init Off-axis differently?
+                param = nn.Parameter(
+                    th.ones(self._full_params_len) * std_init, requires_grad=True)
+                pseudo_cov_par = self._parameterize_full(param)
             chol = FakeModule(pseudo_cov_par)
         elif self.par_strength == self.cov_strength:
             if self.par_strength == Strength.NONE:
                 chol = FakeModule(th.ones(self.action_dim))
             elif self.par_strength == Strength.SCALAR:
-                # TODO: Does it work like this? Test!
-                if self.par_strength == Strength.SCALAR:
                 std = nn.Linear(latent_dim, 1)
-                chol = th.ones(self.action_dim) * std
+                diag_chol = th.ones(self.action_dim) * std
+                chol = self._ensure_positive_func(diag_chol)
             elif self.par_strength == Strength.DIAG:
-                chol = nn.Linear(latent_dim, self.action_dim)
+                diag_chol = nn.Linear(latent_dim, self.action_dim)
+                chol = self._ensure_positive_func(diag_chol)
             elif self.par_strength == Strength.FULL:
-                chol = self._parameterize_full(latent_dim)
+                params = nn.Linear(latent_dim, self._full_params_len)
+                chol = self._parameterize_full(params)
         elif self.par_strength > self.cov_strength:
             raise Exception(
                 'The parameterization can not be stronger than the actual covariance.')
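A note on `FakeModule`, which the NONE branches rely on: it is imported from `..misc.fakeModule` and is not part of this diff, but from its use here it presumably wraps a fixed tensor as an `nn.Module`, so that state-independent chols fit the same `(mean_net, chol_net)` interface as learned ones. A minimal sketch of that idea (hypothetical; the project's real implementation may differ):

```python
import torch as th
from torch import nn


class FakeModuleSketch(nn.Module):
    """An nn.Module that ignores its input and returns a fixed tensor."""

    def __init__(self, value: th.Tensor):
        super().__init__()
        self.value = value

    def forward(self, x: th.Tensor) -> th.Tensor:
        return self.value


chol_net = FakeModuleSketch(th.ones(3))
print(chol_net(th.randn(64)))  # tensor([1., 1., 1.]) regardless of input
```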
@@ -171,52 +177,95 @@ class UniversalGaussianDistribution(SB3_Distribution):
             raise Exception(
                 'That does not even make any sense...')
         else:
-            raise Exception("This Exception can't happen (I think)")
+            raise Exception("This Exception can't happen")

         return mean_actions, chol

-    def _parameterize_full(self, latent_dim):
-        # TODO: Implement various techniques for full parameterization (forcing SPD)
-        raise Exception(
-            'Programmer-was-to-lazy-to-implement-this-Exception')
+    @property
+    def _full_params_len(self):
+        if self.par_type == ParametrizationType.CHOL:
+            return self._flat_chol_len
+        elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
+            return self._flat_chol_len
+        raise Exception()

-    def _parameterize_hybrid_from_diag(self, latent_dim):
+    def _parameterize_full(self, params):
+        if self.par_type == ParametrizationType.CHOL:
+            return self._chol_from_flat(params)
+        elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
+            return self._chol_from_flat_sphe_chol(params)
+        raise Exception()
+
+    def _parameterize_hybrid_from_diag(self, params):
         # TODO: Implement the hybrid-method for DIAG -> FULL (parameters for pearson-correlation-matrix)
         raise Exception(
             'Programmer-was-to-lazy-to-implement-this-Exception')

+    def _parameterize_hybrid_from_scalar(self, latent_dim):
+        # SCALAR => DIAG
+        factor = nn.Linear(latent_dim, 1)
+        par = th.ones(self.action_dim) * \
+            nn.Parameter(th.tensor(1.0), requires_grad=True)
+        diag_chol = self._ensure_positive_func(par * factor[0])
+        return diag_chol
+
+    def _chol_from_flat(self, flat_chol):
+        chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1)
+        return self._ensure_diagonal_positive(chol)
+
+    def _chol_from_flat_sphe_chol(self, flat_sphe_chol):
+        pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol)
+        sphe_chol = fill_triangular(pos_flat_sphe_chol).expand(
+            self._flat_chol_len, -1, -1)
+        chol = self._chol_from_sphe_chol(sphe_chol)
+        return chol
+
+    def _chol_from_sphe_chol(self, sphe_chol):
+        # TODO: Test with batched data
+        # TODO: Make efficient
+        # Note: we must ensure (0-based indices, matching the code below):
+        #   S[i, 0] > 0        for i = 0..n-1             (the radius of each row)
+        #   S[i, j] in (0, pi) for i = 1..n-1, j = 1..i   (the angles)
+        # We already ensure S > 0 in _chol_from_flat_sphe_chol.
+        # We ensure < pi by applying tanh*pi to all applicable elements.
+        S = sphe_chol
+        n = self.action_dim
+        L = th.zeros_like(sphe_chol)
+        for i in range(n):
+            for j in range(i + 1):
+                t = S[i, 0]
+                for k in range(1, j + 1):
+                    t = t * th.sin(th.tanh(S[i, k]) * pi)
+                if i != j:
+                    t = t * th.cos(th.tanh(S[i, j + 1]) * pi)
+                L[i, j] = t
+        return L

     def _ensure_positive_func(self, x):
         return self.enforce_positive_type.apply(x)

-    def _ensure_diagonal_positive(self, pseudo_chol):
-        pseudo_chol.tril(-1) + self._ensure_positive_func(pseudo_chol.diagonal(dim1=-2,
-                                                                               dim2=-1)).diag_embed() + pseudo_chol.triu(1)
+    def _ensure_diagonal_positive(self, chol):
+        if len(chol.shape) == 1:
+            # If our chol is a vector (representing a diagonal chol)
+            return self._ensure_positive_func(chol)
+        return chol.tril(-1) + self._ensure_positive_func(chol.diagonal(dim1=-2,
+                                                                        dim2=-1)).diag_embed() + chol.triu(1)

-    def _parameterize_hybrid_from_scalar(self, latent_dim):
-        factor = nn.Linear(latent_dim, 1)
-        par_cov = th.ones(self.action_dim) * \
-            nn.Parameter(1, requires_grad=True)
-        pseudo_cov = par_cov * factor[0]
-        return pseudo_cov
-
-    def proba_distribution(self, mean_actions: th.Tensor, pseudo_cov: th.Tensor) -> "UniversalGaussianDistribution":
+    def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor) -> "UniversalGaussianDistribution":
         """
-        Create the distribution given its parameters (mean, pseudo_cov)
+        Create the distribution given its parameters (mean, chol)

         :param mean_actions:
-        :param pseudo_cov:
+        :param chol:
         :return:
         """
-        action_std = None
-        # TODO: Needs to be expanded
-        if self.cov_strength == Strength.DIAG:
-            if self.enforce_positive_type == EnforcePositiveType.LOG:
-                action_std = pseudo_cov.exp()
-        if action_std == None:
-            raise Exception('Not yet implemented!')
-        self.distribution = Normal(mean_actions, action_std)
+        if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]:
+            self.distribution = Normal(mean_actions, chol)
+        elif self.cov_strength in [Strength.FULL]:
+            self.distribution = MultivariateNormal(mean_actions, scale_tril=chol)
+        if self.distribution is None:
+            raise Exception('Unable to create torch distribution')
         return self

     def log_prob(self, actions: th.Tensor) -> th.Tensor:
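Two illustrations may help here (editorial sketches, not part of the commit). First, the semantics of `chol` in the new `proba_distribution`: for NONE/SCALAR/DIAG covariances it is a vector of standard deviations fed to `Normal`, while for FULL it is a lower-triangular scale matrix fed to `MultivariateNormal`. A standalone sketch using plain torch distributions rather than the project's wrappers:

```python
import torch as th
from torch.distributions import Normal, MultivariateNormal

mean = th.zeros(3)

diag_chol = th.ones(3) * 0.5              # DIAG case: vector of std-devs
diag_dist = Normal(mean, diag_chol)

full_chol = th.tril(th.ones(3, 3)) * 0.5  # FULL case: lower-triangular factor
full_dist = MultivariateNormal(mean, scale_tril=full_chol)

print(diag_dist.sample().shape, full_dist.sample().shape)
```

Second, the spherical construction in `_chol_from_sphe_chol`, extracted into a standalone, unbatched sketch so it can be sanity-checked in isolation (it mirrors the loop above, including the `tanh(.) * pi` squashing of the angles):

```python
import torch as th
from math import pi


def chol_from_sphe_chol_sketch(S: th.Tensor) -> th.Tensor:
    n = S.shape[0]
    L = th.zeros_like(S)
    for i in range(n):
        for j in range(i + 1):
            t = S[i, 0]                                # radius of row i
            for k in range(1, j + 1):
                t = t * th.sin(th.tanh(S[i, k]) * pi)  # angles mapped into (0, pi)
            if i != j:
                t = t * th.cos(th.tanh(S[i, j + 1]) * pi)
            L[i, j] = t
    return L


S = th.tril(th.rand(4, 4) * 0.9 + 0.1)  # strictly positive entries where used
L = chol_from_sphe_chol_sketch(S)
assert (L.diagonal() > 0).all()         # positive diagonal => valid Cholesky factor
cov = L @ L.T                           # symmetric positive definite by construction
```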
metastable_baselines/misc/tensor_ops.py (new file, 150 lines)
@@ -0,0 +1,150 @@
import torch as th
import numpy as np


def fill_triangular(x, upper=False):
    """
    The following function is derived from TensorFlow Probability
    https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/math/linalg.py#L784
    Copyright (c) 2018 The TensorFlow Probability Authors, licensed under the Apache-2.0 license,
    cf. 3rd-party-licenses.txt file in the root directory of this source tree.
    Creates a (batch of) triangular matrix from a vector of inputs.
    Created matrix can be lower- or upper-triangular. (It is more efficient to
    create the matrix as upper or lower, rather than transpose.)
    Triangular matrix elements are filled in a clockwise spiral. See example
    below.
    If `x.shape` is `[b1, b2, ..., bB, d]` then the output shape is
    `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
    `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.
    Example:
    ```python
    fill_triangular([1, 2, 3, 4, 5, 6])
    # ==> [[4, 0, 0],
    #      [6, 5, 0],
    #      [3, 2, 1]]
    fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
    # ==> [[1, 2, 3],
    #      [0, 5, 6],
    #      [0, 0, 4]]
    ```
    The key trick is to create an upper triangular matrix by concatenating `x`
    and a tail of itself, then reshaping.
    Suppose that we are filling the upper triangle of an `n`-by-`n` matrix `M`
    from a vector `x`. The matrix `M` contains n**2 entries total. The vector `x`
    contains `n * (n+1) / 2` entries. For concreteness, we'll consider `n = 5`
    (so `x` has `15` entries and `M` has `25`). We'll concatenate `x` and `x` with
    the first (`n = 5`) elements removed and reversed:
    ```python
    x = np.arange(15) + 1
    xc = np.concatenate([x, x[5:][::-1]])
    # ==> array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 15, 14, 13,
    #            12, 11, 10,  9,  8,  7,  6])
    # (We add one to the arange result to disambiguate the zeros below the
    # diagonal of our upper-triangular matrix from the first entry in `x`.)
    # Now, when reshaped, we lay this out as a matrix:
    y = np.reshape(xc, [5, 5])
    # ==> array([[ 1,  2,  3,  4,  5],
    #            [ 6,  7,  8,  9, 10],
    #            [11, 12, 13, 14, 15],
    #            [15, 14, 13, 12, 11],
    #            [10,  9,  8,  7,  6]])
    # Finally, zero the elements below the diagonal:
    y = np.triu(y, k=0)
    # ==> array([[ 1,  2,  3,  4,  5],
    #            [ 0,  7,  8,  9, 10],
    #            [ 0,  0, 13, 14, 15],
    #            [ 0,  0,  0, 12, 11],
    #            [ 0,  0,  0,  0,  6]])
    ```
    From this example we see that the resulting matrix is upper-triangular, and
    contains all the entries of `x`, as desired. The rest is details:
    - If `n` is even, `x` doesn't exactly fill an even number of rows (it fills
      `n / 2` rows and half of an additional row), but the whole scheme still
      works.
    - If we want a lower triangular matrix instead of an upper triangular,
      we remove the first `n` elements from `x` rather than from the reversed
      `x`.
    For additional comparisons, a pure numpy version of this function can be found
    in `distribution_util_test.py`, function `_fill_triangular`.
    Args:
      x: `Tensor` representing lower (or upper) triangular elements.
      upper: Python `bool` representing whether output matrix should be upper
        triangular (`True`) or lower triangular (`False`, default).
    Returns:
      tril: `Tensor` with lower (or upper) triangular elements filled from `x`.
    Raises:
      ValueError: if `x` cannot be mapped to a triangular matrix.
    """

    m = np.int32(x.shape[-1])
    # Formula derived by solving for n: m = n(n+1)/2.
    n = np.sqrt(0.25 + 2. * m) - 0.5
    if n != np.floor(n):
        raise ValueError('Input right-most shape ({}) does not '
                         'correspond to a triangular matrix.'.format(m))
    n = np.int32(n)
    new_shape = x.shape[:-1] + (n, n)

    ndims = len(x.shape)
    if upper:
        x_list = [x, th.flip(x[..., n:], dims=[ndims - 1])]
    else:
        x_list = [x[..., n:], th.flip(x, dims=[ndims - 1])]

    x = th.cat(x_list, dim=-1).reshape(new_shape)
    x = th.triu(x) if upper else th.tril(x)
    return x
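The docstring examples are unbatched; a quick sketch of the batched behaviour (editorial example with hypothetical values, assuming `fill_triangular` is importable as defined above): a `[batch, d]` input yields a `[batch, n, n]` output.

```python
import torch as th

flat = th.arange(1., 13.).reshape(2, 6)  # two flat vectors, d = 6 -> n = 3
tril = fill_triangular(flat)
print(tril.shape)  # torch.Size([2, 3, 3])
print(tril[0])     # tensor([[4., 0., 0.], [6., 5., 0.], [3., 2., 1.]])
```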

def fill_triangular_inverse(x, upper=False):
    """
    The following function is derived from TensorFlow Probability
    https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/math/linalg.py#L934
    Copyright (c) 2018 The TensorFlow Probability Authors, licensed under the Apache-2.0 license,
    cf. 3rd-party-licenses.txt file in the root directory of this source tree.
    Creates a vector from a (batch of) triangular matrix.
    The vector is created from the lower-triangular or upper-triangular portion
    depending on the value of the parameter `upper`.
    If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
    `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.
    Example:
    ```python
    fill_triangular_inverse(
      [[4, 0, 0],
       [6, 5, 0],
       [3, 2, 1]])
    # ==> [1, 2, 3, 4, 5, 6]
    fill_triangular_inverse(
      [[1, 2, 3],
       [0, 5, 6],
       [0, 0, 4]], upper=True)
    # ==> [1, 2, 3, 4, 5, 6]
    ```
    Args:
      x: `Tensor` representing lower (or upper) triangular elements.
      upper: Python `bool` representing whether output matrix should be upper
        triangular (`True`) or lower triangular (`False`, default).
    Returns:
      flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
        (or upper) triangular elements from `x`.
    """

    n = np.int32(x.shape[-1])
    m = np.int32((n * (n + 1)) // 2)

    ndims = len(x.shape)
    if upper:
        initial_elements = x[..., 0, :]
        triangular_part = x[..., 1:, :]
    else:
        initial_elements = th.flip(x[..., -1, :], dims=[ndims - 2])
        triangular_part = x[..., :-1, :]

    rotated_triangular_portion = th.flip(
        th.flip(triangular_part, dims=[ndims - 1]), dims=[ndims - 2])
    consolidated_matrix = triangular_part + rotated_triangular_portion

    end_sequence = consolidated_matrix.reshape(x.shape[:-2] + (n * (n - 1),))

    y = th.cat([initial_elements, end_sequence[..., :m - n]], dim=-1)
    return y
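A round-trip sketch tying the two functions together (editorial example, lower-triangular case): `fill_triangular_inverse` recovers exactly the flat vector that `fill_triangular` consumed.

```python
import torch as th

x = th.tensor([1., 2., 3., 4., 5., 6.])
assert th.equal(fill_triangular_inverse(fill_triangular(x)), x)
```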