Fixes + spherical_chol

This commit is contained in:
Dominik Moritz Roth 2022-07-11 17:28:08 +02:00
parent e4440428f8
commit 41e4170b2f
3 changed files with 261 additions and 60 deletions

View File

@@ -10,6 +10,8 @@ such Third Party IP, are set forth below.
 Overview
 --------------------------------------------------------------------------
+# TODO: Tensorflow-Probability
 # TODO: TrustRegionLayers
 boschresearch/trust-region-layers

View File

@@ -4,6 +4,7 @@ from enum import Enum
 import torch as th
 from torch import nn
 from torch.distributions import Normal, MultivariateNormal
+from math import pi
 from stable_baselines3.common.preprocessing import get_action_dim
@@ -13,6 +14,7 @@ from stable_baselines3.common.distributions import DiagGaussianDistribution
 from ..misc.fakeModule import FakeModule
 from ..misc.distTools import new_dist_like
+from ..misc.tensor_ops import fill_triangular

 # TODO: Integrate and Test what I currently have before adding more complexity
 # TODO: Support Squashed Dists (tanh)
@@ -30,9 +32,8 @@ class Strength(Enum):
 class ParametrizationType(Enum):
-    # Currently only Chol is implemented
     CHOL = 1
-    #SPHERICAL_CHOL = 2
+    SPHERICAL_CHOL = 2
     #GIVENS = 3
@@ -77,6 +78,9 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None):
         if ps == Strength.SCALAR and cs == Strength.FULL:
             # TODO: Maybe allow?
             continue
+        if ps == Strength.DIAG and cs == Strength.FULL:
+            # TODO: Implement
+            continue
         if ps == Strength.NONE:
             yield (ps, cs, None, None)
         else:
@@ -95,70 +99,72 @@ class UniversalGaussianDistribution(SB3_Distribution):
     :param action_dim: Dimension of the action space.
     """

-    def __init__(self, action_dim: int):
+    def __init__(self, action_dim: int, neural_strength=Strength.DIAG, cov_strength=Strength.DIAG, parameterization_type=ParametrizationType.CHOL, enforce_positive_type=EnforcePositiveType.ABS, prob_squashing_type=ProbSquashingType.TANH):
         super(UniversalGaussianDistribution, self).__init__()
-        self.par_strength = Strength.DIAG
-        self.cov_strength = Strength.DIAG
-        self.par_type = ParametrizationType.CHOL
-        self.enforce_positive_type = EnforcePositiveType.LOG
-        self.prob_squashing_type = ProbSquashingType.TANH
+        self.action_dim = action_dim  # assumed missing in the hunk: later methods read self.action_dim
+        self.par_strength = neural_strength
+        self.cov_strength = cov_strength
+        self.par_type = parameterization_type
+        self.enforce_positive_type = enforce_positive_type
+        self.prob_squashing_type = prob_squashing_type
         self.distribution = None
+        self._flat_chol_len = action_dim * (action_dim + 1) // 2

     def new_dist_like_me(self, mean, pseudo_chol):
         p = self.distribution
         np = new_dist_like(p, mean, pseudo_chol)
-        new = UniversalGaussianDistribution(self.action_dim)
-        new.par_strength = self.par_strength
-        new.cov_strength = self.cov_strength
-        new.par_type = self.par_type
-        new.enforce_positive_type = self.enforce_positive_type
-        new.prob_squashing_type = self.prob_squashing_type
+        new = UniversalGaussianDistribution(self.action_dim, neural_strength=self.par_strength, cov_strength=self.cov_strength,
+                                            parameterization_type=self.par_type, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type)
         new.distribution = np
         return new

-    def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
+    def proba_distribution_net(self, latent_dim: int, std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
         """
         Create the layers and parameter that represent the distribution:
         one output will be the mean of the Gaussian, the other parameter will be the
-        standard deviation (log std in fact to allow negative values)
+        standard deviation

         :param latent_dim: Dimension of the last layer of the policy (before the action layer)
-        :param log_std_init: Initial value for the log standard deviation
-        :return: We return two nn.Modules (mean, chol).
+        :param std_init: Initial value for the standard deviation
+        :return: We return two nn.Modules (mean, chol). chol can be a vector if the full chol would be a diagonal.
         """
-        # TODO: Allow chol to be vector when only diagonal.
+        assert std_init >= 0.0, "std can not be initialized to a negative value."
         mean_actions = nn.Linear(latent_dim, self.action_dim)

         if self.par_strength == Strength.NONE:
             if self.cov_strength == Strength.NONE:
-                pseudo_cov_par = th.ones(self.action_dim) * log_std_init
+                pseudo_cov_par = th.ones(self.action_dim) * std_init
             elif self.cov_strength == Strength.SCALAR:
                 pseudo_cov_par = th.ones(self.action_dim) * \
-                    nn.Parameter(log_std_init, requires_grad=True)
+                    nn.Parameter(std_init, requires_grad=True)
+                pseudo_cov_par = self._ensure_positive_func(pseudo_cov_par)
             elif self.cov_strength == Strength.DIAG:
                 pseudo_cov_par = nn.Parameter(
-                    th.ones(self.action_dim) * log_std_init, requires_grad=True)
+                    th.ones(self.action_dim) * std_init, requires_grad=True)
+                pseudo_cov_par = self._ensure_positive_func(pseudo_cov_par)
             elif self.cov_strength == Strength.FULL:
-                # TODO: This won't work, need to ensure SPD!
-                # TODO: Off-axis init?
-                pseudo_cov_par = nn.Parameter(
-                    th.diag_embed(th.ones(self.action_dim) * log_std_init), requires_grad=True)
+                # TODO: Init Off-axis differently?
+                param = nn.Parameter(
+                    th.ones(self._full_params_len) * std_init, requires_grad=True)
+                pseudo_cov_par = self._parameterize_full(param)
             chol = FakeModule(pseudo_cov_par)
         elif self.par_strength == self.cov_strength:
-            if self.par_strength == Strength.NONE:
-                chol = FakeModule(th.ones(self.action_dim))
-            elif self.par_strength == Strength.SCALAR:
-                # TODO: Does it work like this? Test!
+            if self.par_strength == Strength.SCALAR:
                 std = nn.Linear(latent_dim, 1)
-                chol = th.ones(self.action_dim) * std
+                diag_chol = th.ones(self.action_dim) * std
+                chol = self._ensure_positive_func(diag_chol)
             elif self.par_strength == Strength.DIAG:
-                chol = nn.Linear(latent_dim, self.action_dim)
+                diag_chol = nn.Linear(latent_dim, self.action_dim)
+                chol = self._ensure_positive_func(diag_chol)
             elif self.par_strength == Strength.FULL:
-                chol = self._parameterize_full(latent_dim)
+                params = nn.Linear(latent_dim, self._full_params_len)
+                chol = self._parameterize_full(params)
         elif self.par_strength > self.cov_strength:
             raise Exception(
                 'The parameterization can not be stronger than the actual covariance.')
@@ -171,52 +177,95 @@ class UniversalGaussianDistribution(SB3_Distribution):
                 raise Exception(
                     'That does not even make any sense...')
             else:
-                raise Exception("This Exception can't happen (I think)")
+                raise Exception("This Exception can't happen")
         return mean_actions, chol
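
To make the wiring above concrete, here is a minimal sketch (not part of the commit) of the `Strength.NONE` / `Strength.DIAG` branch written out by hand, with `EnforcePositiveType.ABS` inlined as `th.abs`; the `FakeModule` wrapper is assumed to simply return its stored tensor for any input and is therefore skipped here:

```python
import torch as th
from torch import nn

latent_dim, action_dim, std_init = 64, 3, 0.5

mean_actions = nn.Linear(latent_dim, action_dim)         # state-dependent mean head
pseudo_cov_par = nn.Parameter(
    th.ones(action_dim) * std_init, requires_grad=True)  # learned, state-independent
diag_chol = th.abs(pseudo_cov_par)                       # EnforcePositiveType.ABS

latent = th.randn(8, latent_dim)
mean = mean_actions(latent)                              # (8, 3)
# diag_chol is the (3,)-shaped vector standing in for a diagonal chol
```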

-    def _parameterize_full(self, latent_dim):
-        # TODO: Implement various techniques for full parameterization (forcing SPD)
-        raise Exception(
-            'Programmer-was-to-lazy-to-implement-this-Exception')
-
-    def _parameterize_hybrid_from_diag(self, latent_dim):
+    @property
+    def _full_params_len(self):
+        if self.par_type == ParametrizationType.CHOL:
+            return self._flat_chol_len
+        elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
+            return self._flat_chol_len
+        raise Exception()
+
+    def _parameterize_full(self, params):
+        if self.par_type == ParametrizationType.CHOL:
+            return self._chol_from_flat(params)
+        elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
+            return self._chol_from_flat_sphe_chol(params)
+        raise Exception()
+
+    def _parameterize_hybrid_from_diag(self, params):
         # TODO: Implement the hybrid-method for DIAG -> FULL (parameters for pearson-correlation-matrix)
         raise Exception(
             'Programmer-was-to-lazy-to-implement-this-Exception')

+    def _parameterize_hybrid_from_scalar(self, latent_dim):
+        # SCALAR => DIAG
+        factor = nn.Linear(latent_dim, 1)
+        par = th.ones(self.action_dim) * \
+            nn.Parameter(1, requires_grad=True)
+        diag_chol = self._ensure_positive_func(par * factor[0])
+        return diag_chol
+
+    def _chol_from_flat(self, flat_chol):
+        chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1)
+        return self._ensure_diagonal_positive(chol)
+
+    def _chol_from_flat_sphe_chol(self, flat_sphe_chol):
+        pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol)
+        sphe_chol = fill_triangular(pos_flat_sphe_chol).expand(
+            self._flat_chol_len, -1, -1)
+        chol = self._chol_from_sphe_chol(sphe_chol)
+        return chol

+    def _chol_from_sphe_chol(self, sphe_chol):
+        # TODO: Test with batched data
+        # TODO: Make efficient
+        # Note:
+        # We must ensure (in 1-indexed math notation):
+        #   S[i,1] > 0          for i = 1..n
+        #   S[i,j] in (0, pi)   for i = 2..n, j = 2..i
+        # We already ensure S > 0 in _chol_from_flat_sphe_chol.
+        # We ensure < pi by applying tanh*pi to all applicable elements.
+        S = sphe_chol
+        n = self.action_dim
+        L = th.zeros_like(sphe_chol)
+        for i in range(n):
+            for j in range(i+1):  # include the diagonal entry (j == i)
+                t = S[i, 0]  # radius of row i (S[i,1] in the 1-indexed note)
+                for k in range(1, j+1):
+                    t *= th.sin(th.tanh(S[i, k])*pi)
+                if i != j:
+                    t *= th.cos(th.tanh(S[i, j+1])*pi)
+                L[i, j] = t
+        return L
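
As a sanity check on this construction, a standalone sketch (not part of the commit; it assumes the 0-indexed radius/diagonal handling used in the loop above) for `n = 3`. Because the radii are positive and all angles land in `(0, pi)`, `L` gets a strictly positive diagonal, so `L @ L.T` is a valid SPD covariance:

```python
import torch as th
from math import pi

n = 3
S = (th.rand(n, n) + 0.1).tril()  # stand-in for a positively-enforced spherical chol
L = th.zeros(n, n)
for i in range(n):
    for j in range(i + 1):
        t = S[i, 0]                                # radius of row i
        for k in range(1, j + 1):
            t = t * th.sin(th.tanh(S[i, k]) * pi)  # sin of angles in (0, pi) is > 0
        if i != j:
            t = t * th.cos(th.tanh(S[i, j + 1]) * pi)
        L[i, j] = t

assert (L.diagonal() > 0).all()  # valid Cholesky diagonal
th.linalg.cholesky(L @ L.T)      # succeeds, i.e. the covariance is SPD
```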

     def _ensure_positive_func(self, x):
         return self.enforce_positive_type.apply(x)

-    def _ensure_diagonal_positive(self, pseudo_chol):
-        pseudo_chol.tril(-1) + self._ensure_positive_func(pseudo_chol.diagonal(dim1=-2,
-                                                                               dim2=-1)).diag_embed() + pseudo_chol.triu(1)
+    def _ensure_diagonal_positive(self, chol):
+        if len(chol.shape) == 1:
+            # If our chol is a vector (representing a diagonal chol)
+            return self._ensure_positive_func(chol)
+        return chol.tril(-1) + self._ensure_positive_func(chol.diagonal(dim1=-2,
+                                                                        dim2=-1)).diag_embed() + chol.triu(1)

-    def _parameterize_hybrid_from_scalar(self, latent_dim):
-        factor = nn.Linear(latent_dim, 1)
-        par_cov = th.ones(self.action_dim) * \
-            nn.Parameter(1, requires_grad=True)
-        pseudo_cov = par_cov * factor[0]
-        return pseudo_cov
-
-    def proba_distribution(self, mean_actions: th.Tensor, pseudo_cov: th.Tensor) -> "UniversalGaussianDistribution":
+    def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor) -> "UniversalGaussianDistribution":
""" """
Create the distribution given its parameters (mean, pseudo_cov) Create the distribution given its parameters (mean, chol)
:param mean_actions: :param mean_actions:
:param pseudo_cov: :param chol:
:return: :return:
""" """
action_std = None if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]:
# TODO: Needs to be expanded self.distribution = Normal(mean_actions, chol)
if self.cov_strength == Strength.DIAG: elif self.cov_strength in [Strength.FULL]:
if self.enforce_positive_type == EnforcePositiveType.LOG: self.distribution = MultivariateNormal(mean_actions, cholesky=chol)
action_std = pseudo_cov.exp()
if action_std == None:
raise Exception('Not yet implemented!')
self.distribution = Normal(mean_actions, action_std)
if self.distribution == None: if self.distribution == None:
raise Exception('Not yet implemented!') raise Exception('Unable to create torch distribution')
return self return self
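
For reference, the two branches above in isolation (a sketch, not part of the commit). Note that torch's `MultivariateNormal` expects the Cholesky factor via its `scale_tril` keyword:

```python
import torch as th
from torch.distributions import Normal, MultivariateNormal

mean = th.zeros(3)

# NONE / SCALAR / DIAG: chol is a vector of positive standard deviations
d_diag = Normal(mean, th.tensor([0.5, 1.0, 2.0]))

# FULL: chol is a lower-triangular Cholesky factor of the covariance
L = th.tensor([[0.5, 0.0, 0.0],
               [0.1, 1.0, 0.0],
               [0.2, 0.3, 2.0]])
d_full = MultivariateNormal(mean, scale_tril=L)

action = d_full.sample()
log_prob = d_full.log_prob(action)
```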

     def log_prob(self, actions: th.Tensor) -> th.Tensor:

View File

@@ -0,0 +1,150 @@
import torch as th
import numpy as np


def fill_triangular(x, upper=False):
"""
The following function is derived from TensorFlow Probability
https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/math/linalg.py#L784
Copyright (c) 2018 The TensorFlow Probability Authors, licensed under the Apache-2.0 license,
cf. 3rd-party-licenses.txt file in the root directory of this source tree.
Creates a (batch of) triangular matrix from a vector of inputs.
Created matrix can be lower- or upper-triangular. (It is more efficient to
create the matrix as upper or lower, rather than transpose.)
Triangular matrix elements are filled in a clockwise spiral. See example,
below.
If `x.shape` is `[b1, b2, ..., bB, d]` then the output shape is
`[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
`n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.
Example:
```python
fill_triangular([1, 2, 3, 4, 5, 6])
# ==> [[4, 0, 0],
# [6, 5, 0],
# [3, 2, 1]]
fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
# ==> [[1, 2, 3],
# [0, 5, 6],
# [0, 0, 4]]
```
The key trick is to create an upper triangular matrix by concatenating `x`
and a tail of itself, then reshaping.
Suppose that we are filling the upper triangle of an `n`-by-`n` matrix `M`
from a vector `x`. The matrix `M` contains n**2 entries total. The vector `x`
contains `n * (n+1) / 2` entries. For concreteness, we'll consider `n = 5`
(so `x` has `15` entries and `M` has `25`). We'll concatenate `x` and `x` with
the first (`n = 5`) elements removed and reversed:
```python
x = np.arange(15) + 1
xc = np.concatenate([x, x[5:][::-1]])
# ==> array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 14, 13,
# 12, 11, 10, 9, 8, 7, 6])
# (We add one to the arange result to disambiguate the zeros below the
# diagonal of our upper-triangular matrix from the first entry in `x`.)
# Now, when reshaped, this lays out as a matrix:
y = np.reshape(xc, [5, 5])
# ==> array([[ 1, 2, 3, 4, 5],
# [ 6, 7, 8, 9, 10],
# [11, 12, 13, 14, 15],
# [15, 14, 13, 12, 11],
# [10, 9, 8, 7, 6]])
# Finally, zero the elements below the diagonal:
y = np.triu(y, k=0)
# ==> array([[ 1, 2, 3, 4, 5],
# [ 0, 7, 8, 9, 10],
# [ 0, 0, 13, 14, 15],
# [ 0, 0, 0, 12, 11],
# [ 0, 0, 0, 0, 6]])
```
From this example we see that the resulting matrix is upper-triangular, and
contains all the entries of x, as desired. The rest is details:
- If `n` is even, `x` doesn't exactly fill an even number of rows (it fills
`n / 2` rows and half of an additional row), but the whole scheme still
works.
- If we want a lower triangular matrix instead of an upper triangular,
we remove the first `n` elements from `x` rather than from the reversed
`x`.
For additional comparisons, a pure numpy version of this function can be found
in `distribution_util_test.py`, function `_fill_triangular`.
Args:
x: `Tensor` representing lower (or upper) triangular elements.
upper: Python `bool` representing whether output matrix should be upper
triangular (`True`) or lower triangular (`False`, default).
Returns:
tril: `Tensor` with lower (or upper) triangular elements filled from `x`.
Raises:
ValueError: if `x` cannot be mapped to a triangular matrix.
"""
    m = np.int32(x.shape[-1])
    # Formula derived by solving for n: m = n(n+1)/2.
    n = np.sqrt(0.25 + 2. * m) - 0.5
    if n != np.floor(n):
        raise ValueError('Input right-most shape ({}) does not '
                         'correspond to a triangular matrix.'.format(m))
    n = np.int32(n)
    new_shape = x.shape[:-1] + (n, n)
    ndims = len(x.shape)
    if upper:
        x_list = [x, th.flip(x[..., n:], dims=[ndims - 1])]
    else:
        x_list = [x[..., n:], th.flip(x, dims=[ndims - 1])]
    x = th.cat(x_list, dim=-1).reshape(new_shape)
    x = th.triu(x) if upper else th.tril(x)
    return x
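
A quick check (not part of the commit) that the torch port reproduces the docstring examples:

```python
import torch as th

x = th.tensor([1., 2., 3., 4., 5., 6.])
fill_triangular(x)
# tensor([[4., 0., 0.],
#         [6., 5., 0.],
#         [3., 2., 1.]])
fill_triangular(x, upper=True)
# tensor([[1., 2., 3.],
#         [0., 5., 6.],
#         [0., 0., 4.]])
```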
def fill_triangular_inverse(x, upper=False):
"""
The following function is derived from TensorFlow Probability
https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/math/linalg.py#L934
Copyright (c) 2018 The TensorFlow Probability Authors, licensed under the Apache-2.0 license,
cf. 3rd-party-licenses.txt file in the root directory of this source tree.
Creates a vector from a (batch of) triangular matrix.
The vector is created from the lower-triangular or upper-triangular portion
depending on the value of the parameter `upper`.
If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
`[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.
Example:
```python
fill_triangular_inverse(
[[4, 0, 0],
[6, 5, 0],
[3, 2, 1]])
# ==> [1, 2, 3, 4, 5, 6]
fill_triangular_inverse(
[[1, 2, 3],
[0, 5, 6],
[0, 0, 4]], upper=True)
# ==> [1, 2, 3, 4, 5, 6]
```
Args:
x: `Tensor` representing lower (or upper) triangular elements.
upper: Python `bool` representing whether output matrix should be upper
triangular (`True`) or lower triangular (`False`, default).
Returns:
flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
(or upper) triangular elements from `x`.
"""
    n = np.int32(x.shape[-1])
    m = np.int32((n * (n + 1)) // 2)
    ndims = len(x.shape)
    if upper:
        initial_elements = x[..., 0, :]
        triangular_part = x[..., 1:, :]
    else:
        initial_elements = th.flip(x[..., -1, :], dims=[ndims - 2])
        triangular_part = x[..., :-1, :]
    rotated_triangular_portion = th.flip(
        th.flip(triangular_part, dims=[ndims - 1]), dims=[ndims - 2])
    consolidated_matrix = triangular_part + rotated_triangular_portion
    end_sequence = consolidated_matrix.reshape(x.shape[:-2] + (n * (n - 1),))
    y = th.cat([initial_elements, end_sequence[..., :m - n]], dim=-1)
    return y
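
A hedged sketch (not part of the commit) of the round-trip property the pair is meant to satisfy, usable as a smoke test for the port:

```python
import torch as th

for n in (2, 3, 5):
    d = n * (n + 1) // 2
    x = th.arange(1., d + 1.)                    # d = n(n+1)/2 triangular entries
    L = fill_triangular(x)                       # (n, n) lower triangular
    assert th.equal(fill_triangular_inverse(L), x)
    U = fill_triangular(x, upper=True)
    assert th.equal(fill_triangular_inverse(U, upper=True), x)
```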