Fixes + spherical_chol

This commit is contained in:
Dominik Moritz Roth 2022-07-11 17:28:08 +02:00
parent e4440428f8
commit 41e4170b2f
3 changed files with 261 additions and 60 deletions

View File

@ -10,6 +10,8 @@ such Third Party IP, are set forth below.
Overview
--------------------------------------------------------------------------
# TODO: Tensorflow-Probability
# TODO: TrustRegionLayers
boschresearch/trust-region-layers

View File

@ -4,6 +4,7 @@ from enum import Enum
import torch as th
from torch import nn
from torch.distributions import Normal, MultivariateNormal
from math import pi
from stable_baselines3.common.preprocessing import get_action_dim
@ -13,6 +14,7 @@ from stable_baselines3.common.distributions import DiagGaussianDistribution
from ..misc.fakeModule import FakeModule
from ..misc.distTools import new_dist_like
from ..misc.tensor_ops import fill_triangular
# TODO: Integrate and Test what I currently have before adding more complexity
# TODO: Support Squashed Dists (tanh)
@ -30,9 +32,8 @@ class Strength(Enum):
class ParametrizationType(Enum):
# Currently only Chol is implemented
CHOL = 1
#SPHERICAL_CHOL = 2
SPHERICAL_CHOL = 2
#GIVENS = 3
@ -77,6 +78,9 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng
if ps == Strength.SCALAR and cs == Strength.FULL:
# TODO: Maybe allow?
continue
if ps == Strength.DIAG and cs == Strength.FULL:
# TODO: Implement
continue
if ps == Strength.NONE:
yield (ps, cs, None, None)
else:
@ -95,70 +99,72 @@ class UniversalGaussianDistribution(SB3_Distribution):
:param action_dim: Dimension of the action space.
"""
def __init__(self, action_dim: int):
def __init__(self, action_dim: int, neural_strength=Strength.DIAG, cov_strength=Strength.DIAG, parameterization_type=Strength.CHOL, enforce_positive_type=EnforcePositiveType.ABS, prob_squashing_type=ProbSquashingType.TANH):
super(UniversalGaussianDistribution, self).__init__()
self.par_strength = Strength.DIAG
self.cov_strength = Strength.DIAG
self.par_type = ParametrizationType.CHOL
self.enforce_positive_type = EnforcePositiveType.LOG
self.prob_squashing_type = ProbSquashingType.TANH
self.par_strength = neural_strength
self.cov_strength = cov_strength
self.par_type = parameterization_type
self.enforce_positive_type = enforce_positive_type
self.prob_squashing_type = prob_squashing_type
self.distribution = None
self._flat_chol_len = action_dim * (action_dim + 1) // 2
def new_dist_like_me(self, mean, pseudo_chol):
p = self.distribution
np = new_dist_like(p, mean, pseudo_chol)
new = UniversalGaussianDistribution(self.action_dim)
new.par_strength = self.par_strength
new.cov_strength = self.cov_strength
new.par_type = self.par_type
new.enforce_positive_type = self.enforce_positive_type
new.prob_squashing_type = self.prob_squashing_type
new = UniversalGaussianDistribution(self.action_dim, neural_strength=self.par_strength, cov_strength=self.cov_strength,
parameterization_type=self.par_strength, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type)
new.distribution = np
return new
def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
def proba_distribution_net(self, latent_dim: int, std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
"""
Create the layers and parameter that represent the distribution:
one output will be the mean of the Gaussian, the other parameter will be the
standard deviation (log std in fact to allow negative values)
standard deviation
:param latent_dim: Dimension of the last layer of the policy (before the action layer)
:param log_std_init: Initial value for the log standard deviation
:return: We return two nn.Modules (mean, chol).
:param std_init: Initial value for the standard deviation
:return: We return two nn.Modules (mean, chol). chol can be a vector if the full chol would be a diagonal.
"""
assert std_init >= 0.0, "std can not be initialized to a negative value."
# TODO: Allow chol to be vector when only diagonal.
mean_actions = nn.Linear(latent_dim, self.action_dim)
if self.par_strength == Strength.NONE:
if self.cov_strength == Strength.NONE:
pseudo_cov_par = th.ones(self.action_dim) * log_std_init
pseudo_cov_par = th.ones(self.action_dim) * std_init
elif self.cov_strength == Strength.SCALAR:
pseudo_cov_par = th.ones(self.action_dim) * \
nn.Parameter(log_std_init, requires_grad=True)
nn.Parameter(std_init, requires_grad=True)
pseudo_cov_par = self._ensure_positive_func(pseudo_cov_par)
elif self.cov_strength == Strength.DIAG:
pseudo_cov_par = nn.Parameter(
th.ones(self.action_dim) * log_std_init, requires_grad=True)
th.ones(self.action_dim) * std_init, requires_grad=True)
pseudo_cov_par = self._ensure_positive_func(pseudo_cov_par)
elif self.cov_strength == Strength.FULL:
# TODO: This won't work, need to ensure SPD!
# TODO: Off-axis init?
pseudo_cov_par = nn.Parameter(
th.diag_embed(th.ones(self.action_dim) * log_std_init), requires_grad=True)
# TODO: Init Off-axis differently?
param = nn.Parameter(
th.ones(self._full_params_len) * std_init, requires_grad=True)
pseudo_cov_par = self._parameterize_full(param)
chol = FakeModule(pseudo_cov_par)
elif self.par_strength == self.cov_strength:
if self.par_strength == Strength.NONE:
chol = FakeModule(th.ones(self.action_dim))
elif self.par_strength == Strength.SCALAR:
# TODO: Does it work like this? Test!
if self.par_strength == Strength.SCALAR:
std = nn.Linear(latent_dim, 1)
chol = th.ones(self.action_dim) * std
diag_chol = th.ones(self.action_dim) * std
chol = self._ensure_positive_func(diag_chol)
elif self.par_strength == Strength.DIAG:
chol = nn.Linear(latent_dim, self.action_dim)
diag_chol = nn.Linear(latent_dim, self.action_dim)
chol = self._ensure_positive_func(diag_chol)
elif self.par_strength == Strength.FULL:
chol = self._parameterize_full(latent_dim)
params = nn.Linear(latent_dim, self._full_params_len)
chol = self._parameterize_full(params)
elif self.par_strength > self.cov_strength:
raise Exception(
'The parameterization can not be stronger than the actual covariance.')
@ -171,52 +177,95 @@ class UniversalGaussianDistribution(SB3_Distribution):
raise Exception(
'That does not even make any sense...')
else:
raise Exception("This Exception can't happen (I think)")
raise Exception("This Exception can't happen")
return mean_actions, chol
def _parameterize_full(self, latent_dim):
# TODO: Implement various techniques for full parameterization (forcing SPD)
raise Exception(
'Programmer-was-to-lazy-to-implement-this-Exception')
@property
def _full_params_len(self):
if self.par_type == ParametrizationType.CHOL:
return self._flat_chol_len
elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
return self._flat_chol_len
raise Exception()
def _parameterize_hybrid_from_diag(self, latent_dim):
def _parameterize_full(self, params):
if self.par_type == ParametrizationType.CHOL:
return self._chol_from_flat(params)
elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
return self._chol_from_flat_sphe_chol(params)
raise Exception()
def _parameterize_hybrid_from_diag(self, params):
# TODO: Implement the hybrid-method for DIAG -> FULL (parameters for pearson-correlation-matrix)
raise Exception(
'Programmer-was-to-lazy-to-implement-this-Exception')
def _parameterize_hybrid_from_scalar(self, latent_dim):
# SCALAR => DIAG
factor = nn.Linear(latent_dim, 1)
par = th.ones(self.action_dim) * \
nn.Parameter(1, requires_grad=True)
diag_chol = self._ensure_positive_func(par * factor[0])
return diag_chol
def _chol_from_flat(self, flat_chol):
chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1)
return self._ensure_diagonal_positive(chol)
def _chol_from_flat_sphe_chol(self, flat_sphe_chol):
pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol)
sphe_chol = fill_triangular(pos_flat_sphe_chol).expand(
self._flat_chol_len, -1, -1)
chol = self._chol_from_sphe_chol(sphe_chol)
return chol
def _chol_from_sphe_chol(self, sphe_chol):
# TODO: Test with batched data
# TODO: Make efficient
# Note:
# We must should ensure:
# S[i,1] > 0 where i = 1..n
# S[i,j] e (0, pi) where i = 2..n, j = 2..i
# We already ensure S > 0 in _chol_from_flat_sphe_chol
# We ensure < pi by applying tanh*pi to all applicable elements
S = sphe_chol
n = self.action_dim
L = th.zeros_like(sphe_chol)
for i in range(n):
for j in range(i):
t = S[i, 1]
for k in range(1, j+1):
t *= th.sin(th.tanh(S[i, k])*pi)
if i != j:
t *= th.cos(th.tanh(S[i, j+1])*pi)
L[i, j] = t
return L
def _ensure_positive_func(self, x):
return self.enforce_positive_type.apply(x)
def _ensure_diagonal_positive(self, pseudo_chol):
pseudo_chol.tril(-1) + self._ensure_positive_func(pseudo_chol.diagonal(dim1=-2,
dim2=-1)).diag_embed() + pseudo_chol.triu(1)
def _ensure_diagonal_positive(self, chol):
if len(chol.shape) == 1:
# If our chol is a vector (representing a diagonal chol)
return self._ensure_positive_func(chol)
return chol.tril(-1) + self._ensure_positive_func(chol.diagonal(dim1=-2,
dim2=-1)).diag_embed() + chol.triu(1)
def _parameterize_hybrid_from_scalar(self, latent_dim):
factor = nn.Linear(latent_dim, 1)
par_cov = th.ones(self.action_dim) * \
nn.Parameter(1, requires_grad=True)
pseudo_cov = par_cov * factor[0]
return pseudo_cov
def proba_distribution(self, mean_actions: th.Tensor, pseudo_cov: th.Tensor) -> "UniversalGaussianDistribution":
def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor) -> "UniversalGaussianDistribution":
"""
Create the distribution given its parameters (mean, pseudo_cov)
Create the distribution given its parameters (mean, chol)
:param mean_actions:
:param pseudo_cov:
:param chol:
:return:
"""
action_std = None
# TODO: Needs to be expanded
if self.cov_strength == Strength.DIAG:
if self.enforce_positive_type == EnforcePositiveType.LOG:
action_std = pseudo_cov.exp()
if action_std == None:
raise Exception('Not yet implemented!')
self.distribution = Normal(mean_actions, action_std)
if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]:
self.distribution = Normal(mean_actions, chol)
elif self.cov_strength in [Strength.FULL]:
self.distribution = MultivariateNormal(mean_actions, cholesky=chol)
if self.distribution == None:
raise Exception('Not yet implemented!')
raise Exception('Unable to create torch distribution')
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:

View File

@ -0,0 +1,150 @@
import torch as th
import numpy as np
def fill_triangular(x, upper=False):
"""
The following function is derived from TensorFlow Probability
https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/math/linalg.py#L784
Copyright (c) 2018 The TensorFlow Probability Authors, licensed under the Apache-2.0 license,
cf. 3rd-party-licenses.txt file in the root directory of this source tree.
Creates a (batch of) triangular matrix from a vector of inputs.
Created matrix can be lower- or upper-triangular. (It is more efficient to
create the matrix as upper or lower, rather than transpose.)
Triangular matrix elements are filled in a clockwise spiral. See example,
below.
If `x.shape` is `[b1, b2, ..., bB, d]` then the output shape is
`[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
`n = int(np.sqrt(0.25 + 2. * m) - 0.5)`.
Example:
```python
fill_triangular([1, 2, 3, 4, 5, 6])
# ==> [[4, 0, 0],
# [6, 5, 0],
# [3, 2, 1]]
fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
# ==> [[1, 2, 3],
# [0, 5, 6],
# [0, 0, 4]]
```
The key trick is to create an upper triangular matrix by concatenating `x`
and a tail of itself, then reshaping.
Suppose that we are filling the upper triangle of an `n`-by-`n` matrix `M`
from a vector `x`. The matrix `M` contains n**2 entries total. The vector `x`
contains `n * (n+1) / 2` entries. For concreteness, we'll consider `n = 5`
(so `x` has `15` entries and `M` has `25`). We'll concatenate `x` and `x` with
the first (`n = 5`) elements removed and reversed:
```python
x = np.arange(15) + 1
xc = np.concatenate([x, x[5:][::-1]])
# ==> array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 14, 13,
# 12, 11, 10, 9, 8, 7, 6])
# (We add one to the arange result to disambiguate the zeros below the
# diagonal of our upper-triangular matrix from the first entry in `x`.)
# Now, when reshapedlay this out as a matrix:
y = np.reshape(xc, [5, 5])
# ==> array([[ 1, 2, 3, 4, 5],
# [ 6, 7, 8, 9, 10],
# [11, 12, 13, 14, 15],
# [15, 14, 13, 12, 11],
# [10, 9, 8, 7, 6]])
# Finally, zero the elements below the diagonal:
y = np.triu(y, k=0)
# ==> array([[ 1, 2, 3, 4, 5],
# [ 0, 7, 8, 9, 10],
# [ 0, 0, 13, 14, 15],
# [ 0, 0, 0, 12, 11],
# [ 0, 0, 0, 0, 6]])
```
From this example we see that the resuting matrix is upper-triangular, and
contains all the entries of x, as desired. The rest is details:
- If `n` is even, `x` doesn't exactly fill an even number of rows (it fills
`n / 2` rows and half of an additional row), but the whole scheme still
works.
- If we want a lower triangular matrix instead of an upper triangular,
we remove the first `n` elements from `x` rather than from the reversed
`x`.
For additional comparisons, a pure numpy version of this function can be found
in `distribution_util_test.py`, function `_fill_triangular`.
Args:
x: `Tensor` representing lower (or upper) triangular elements.
upper: Python `bool` representing whether output matrix should be upper
triangular (`True`) or lower triangular (`False`, default).
Returns:
tril: `Tensor` with lower (or upper) triangular elements filled from `x`.
Raises:
ValueError: if `x` cannot be mapped to a triangular matrix.
"""
m = np.int32(x.shape[-1])
# Formula derived by solving for n: m = n(n+1)/2.
n = np.sqrt(0.25 + 2. * m) - 0.5
if n != np.floor(n):
raise ValueError('Input right-most shape ({}) does not '
'correspond to a triangular matrix.'.format(m))
n = np.int32(n)
new_shape = x.shape[:-1] + (n, n)
ndims = len(x.shape)
if upper:
x_list = [x, th.flip(x[..., n:], dims=[ndims - 1])]
else:
x_list = [x[..., n:], th.flip(x, dims=[ndims - 1])]
x = th.cat(x_list, dim=-1).reshape(new_shape)
x = th.triu(x) if upper else th.tril(x)
return x
def fill_triangular_inverse(x, upper=False):
"""
The following function is derived from TensorFlow Probability
https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/math/linalg.py#L934
Copyright (c) 2018 The TensorFlow Probability Authors, licensed under the Apache-2.0 license,
cf. 3rd-party-licenses.txt file in the root directory of this source tree.
Creates a vector from a (batch of) triangular matrix.
The vector is created from the lower-triangular or upper-triangular portion
depending on the value of the parameter `upper`.
If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
`[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.
Example:
```python
fill_triangular_inverse(
[[4, 0, 0],
[6, 5, 0],
[3, 2, 1]])
# ==> [1, 2, 3, 4, 5, 6]
fill_triangular_inverse(
[[1, 2, 3],
[0, 5, 6],
[0, 0, 4]], upper=True)
# ==> [1, 2, 3, 4, 5, 6]
```
Args:
x: `Tensor` representing lower (or upper) triangular elements.
upper: Python `bool` representing whether output matrix should be upper
triangular (`True`) or lower triangular (`False`, default).
Returns:
flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
(or upper) triangular elements from `x`.
"""
n = np.int32(x.shape[-1])
m = np.int32((n * (n + 1)) // 2)
ndims = len(x.shape)
if upper:
initial_elements = x[..., 0, :]
triangular_part = x[..., 1:, :]
else:
initial_elements = ch.flip(x[..., -1, :], dims=[ndims - 2])
triangular_part = x[..., :-1, :]
rotated_triangular_portion = ch.flip(
th.flip(triangular_part, dims=[ndims - 1]), dims=[ndims - 2])
consolidated_matrix = triangular_part + rotated_triangular_portion
end_sequence = consolidated_matrix.reshape(x.shape[:-2] + (n * (n - 1),))
y = th.cat([initial_elements, end_sequence[..., :m - n]], dim=-1)
return y