Fixes + spherical_chol
This commit is contained in:
parent
e4440428f8
commit
41e4170b2f
@ -10,6 +10,8 @@ such Third Party IP, are set forth below.
|
|||||||
Overview
|
Overview
|
||||||
--------------------------------------------------------------------------
|
--------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# TODO: Tensorflow-Probability
|
||||||
|
|
||||||
# TODO: TrustRegionLayers
|
# TODO: TrustRegionLayers
|
||||||
|
|
||||||
boschresearch/trust-region-layers
|
boschresearch/trust-region-layers
|
||||||
|
@ -4,6 +4,7 @@ from enum import Enum
|
|||||||
import torch as th
|
import torch as th
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.distributions import Normal, MultivariateNormal
|
from torch.distributions import Normal, MultivariateNormal
|
||||||
|
from math import pi
|
||||||
|
|
||||||
from stable_baselines3.common.preprocessing import get_action_dim
|
from stable_baselines3.common.preprocessing import get_action_dim
|
||||||
|
|
||||||
@ -13,6 +14,7 @@ from stable_baselines3.common.distributions import DiagGaussianDistribution
|
|||||||
|
|
||||||
from ..misc.fakeModule import FakeModule
|
from ..misc.fakeModule import FakeModule
|
||||||
from ..misc.distTools import new_dist_like
|
from ..misc.distTools import new_dist_like
|
||||||
|
from ..misc.tensor_ops import fill_triangular
|
||||||
|
|
||||||
# TODO: Integrate and Test what I currently have before adding more complexity
|
# TODO: Integrate and Test what I currently have before adding more complexity
|
||||||
# TODO: Support Squashed Dists (tanh)
|
# TODO: Support Squashed Dists (tanh)
|
||||||
@ -30,9 +32,8 @@ class Strength(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class ParametrizationType(Enum):
    """Identifies how raw parameters are mapped to a full Cholesky factor."""
    # Plain Cholesky parameterization.
    CHOL = 1
    # Spherical Cholesky parameterization.
    SPHERICAL_CHOL = 2
    #GIVENS = 3
|
||||||
|
|
||||||
|
|
||||||
@ -77,6 +78,9 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng
|
|||||||
if ps == Strength.SCALAR and cs == Strength.FULL:
|
if ps == Strength.SCALAR and cs == Strength.FULL:
|
||||||
# TODO: Maybe allow?
|
# TODO: Maybe allow?
|
||||||
continue
|
continue
|
||||||
|
if ps == Strength.DIAG and cs == Strength.FULL:
|
||||||
|
# TODO: Implement
|
||||||
|
continue
|
||||||
if ps == Strength.NONE:
|
if ps == Strength.NONE:
|
||||||
yield (ps, cs, None, None)
|
yield (ps, cs, None, None)
|
||||||
else:
|
else:
|
||||||
@ -95,70 +99,72 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
:param action_dim: Dimension of the action space.
|
:param action_dim: Dimension of the action space.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, action_dim: int, neural_strength=Strength.DIAG, cov_strength=Strength.DIAG, parameterization_type=ParametrizationType.CHOL, enforce_positive_type=EnforcePositiveType.ABS, prob_squashing_type=ProbSquashingType.TANH):
    """
    Configure the distribution; the torch distribution itself is created later.

    :param action_dim: Dimension of the action space.
    :param neural_strength: Strength of the covariance part that is produced by the
        network (stored as ``par_strength``).
    :param cov_strength: Strength of the resulting covariance (stored as ``cov_strength``).
    :param parameterization_type: ``ParametrizationType`` used for full covariances
        (stored as ``par_type``).
    :param enforce_positive_type: Transform used to force values to be positive.
    :param prob_squashing_type: Squashing applied to the distribution (stored only here).
    """
    super(UniversalGaussianDistribution, self).__init__()
    # Other methods of this class read self.action_dim, so it has to be stored here.
    self.action_dim = action_dim
    self.par_strength = neural_strength
    self.cov_strength = cov_strength
    # The default used to be Strength.CHOL; CHOL is a member of ParametrizationType.
    self.par_type = parameterization_type
    self.enforce_positive_type = enforce_positive_type
    self.prob_squashing_type = prob_squashing_type

    self.distribution = None

    # Number of entries in the lower triangle (incl. diagonal) of an
    # action_dim x action_dim matrix.
    self._flat_chol_len = action_dim * (action_dim + 1) // 2
|
||||||
|
|
||||||
def new_dist_like_me(self, mean, pseudo_chol):
    """
    Create a new UniversalGaussianDistribution configured like this one.

    :param mean: mean handed to ``new_dist_like`` for the new torch distribution
    :param pseudo_chol: cholesky-like parameter handed to ``new_dist_like``
    :return: a new instance carrying the same settings as ``self`` and wrapping the
        distribution returned by ``new_dist_like``
    """
    torch_dist = new_dist_like(self.distribution, mean, pseudo_chol)
    new = UniversalGaussianDistribution(
        self.action_dim,
        neural_strength=self.par_strength,
        cov_strength=self.cov_strength,
        # Must be the parametrization type, not the parameter strength.
        parameterization_type=self.par_type,
        enforce_positive_type=self.enforce_positive_type,
        prob_squashing_type=self.prob_squashing_type,
    )
    new.distribution = torch_dist

    return new
|
||||||
|
|
||||||
def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
|
def proba_distribution_net(self, latent_dim: int, std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
|
||||||
"""
|
"""
|
||||||
Create the layers and parameter that represent the distribution:
|
Create the layers and parameter that represent the distribution:
|
||||||
one output will be the mean of the Gaussian, the other parameter will be the
|
one output will be the mean of the Gaussian, the other parameter will be the
|
||||||
standard deviation (log std in fact to allow negative values)
|
standard deviation
|
||||||
|
|
||||||
:param latent_dim: Dimension of the last layer of the policy (before the action layer)
|
:param latent_dim: Dimension of the last layer of the policy (before the action layer)
|
||||||
:param log_std_init: Initial value for the log standard deviation
|
:param std_init: Initial value for the standard deviation
|
||||||
:return: We return two nn.Modules (mean, chol).
|
:return: We return two nn.Modules (mean, chol). chol can be a vector if the full chol would be a diagonal.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
assert std_init >= 0.0, "std can not be initialized to a negative value."
|
||||||
|
|
||||||
# TODO: Allow chol to be vector when only diagonal.
|
# TODO: Allow chol to be vector when only diagonal.
|
||||||
|
|
||||||
mean_actions = nn.Linear(latent_dim, self.action_dim)
|
mean_actions = nn.Linear(latent_dim, self.action_dim)
|
||||||
|
|
||||||
if self.par_strength == Strength.NONE:
|
if self.par_strength == Strength.NONE:
|
||||||
if self.cov_strength == Strength.NONE:
|
if self.cov_strength == Strength.NONE:
|
||||||
pseudo_cov_par = th.ones(self.action_dim) * log_std_init
|
pseudo_cov_par = th.ones(self.action_dim) * std_init
|
||||||
elif self.cov_strength == Strength.SCALAR:
|
elif self.cov_strength == Strength.SCALAR:
|
||||||
pseudo_cov_par = th.ones(self.action_dim) * \
|
pseudo_cov_par = th.ones(self.action_dim) * \
|
||||||
nn.Parameter(log_std_init, requires_grad=True)
|
nn.Parameter(std_init, requires_grad=True)
|
||||||
|
pseudo_cov_par = self._ensure_positive_func(pseudo_cov_par)
|
||||||
elif self.cov_strength == Strength.DIAG:
|
elif self.cov_strength == Strength.DIAG:
|
||||||
pseudo_cov_par = nn.Parameter(
|
pseudo_cov_par = nn.Parameter(
|
||||||
th.ones(self.action_dim) * log_std_init, requires_grad=True)
|
th.ones(self.action_dim) * std_init, requires_grad=True)
|
||||||
|
pseudo_cov_par = self._ensure_positive_func(pseudo_cov_par)
|
||||||
elif self.cov_strength == Strength.FULL:
|
elif self.cov_strength == Strength.FULL:
|
||||||
# TODO: This won't work, need to ensure SPD!
|
# TODO: Init Off-axis differently?
|
||||||
# TODO: Off-axis init?
|
param = nn.Parameter(
|
||||||
pseudo_cov_par = nn.Parameter(
|
th.ones(self._full_params_len) * std_init, requires_grad=True)
|
||||||
th.diag_embed(th.ones(self.action_dim) * log_std_init), requires_grad=True)
|
pseudo_cov_par = self._parameterize_full(param)
|
||||||
chol = FakeModule(pseudo_cov_par)
|
chol = FakeModule(pseudo_cov_par)
|
||||||
elif self.par_strength == self.cov_strength:
|
elif self.par_strength == self.cov_strength:
|
||||||
if self.par_strength == Strength.NONE:
|
if self.par_strength == Strength.SCALAR:
|
||||||
chol = FakeModule(th.ones(self.action_dim))
|
|
||||||
elif self.par_strength == Strength.SCALAR:
|
|
||||||
# TODO: Does it work like this? Test!
|
|
||||||
std = nn.Linear(latent_dim, 1)
|
std = nn.Linear(latent_dim, 1)
|
||||||
chol = th.ones(self.action_dim) * std
|
diag_chol = th.ones(self.action_dim) * std
|
||||||
|
chol = self._ensure_positive_func(diag_chol)
|
||||||
elif self.par_strength == Strength.DIAG:
|
elif self.par_strength == Strength.DIAG:
|
||||||
chol = nn.Linear(latent_dim, self.action_dim)
|
diag_chol = nn.Linear(latent_dim, self.action_dim)
|
||||||
|
chol = self._ensure_positive_func(diag_chol)
|
||||||
elif self.par_strength == Strength.FULL:
|
elif self.par_strength == Strength.FULL:
|
||||||
chol = self._parameterize_full(latent_dim)
|
params = nn.Linear(latent_dim, self._full_params_len)
|
||||||
|
chol = self._parameterize_full(params)
|
||||||
elif self.par_strength > self.cov_strength:
|
elif self.par_strength > self.cov_strength:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
'The parameterization can not be stronger than the actual covariance.')
|
'The parameterization can not be stronger than the actual covariance.')
|
||||||
@ -171,52 +177,95 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
raise Exception(
|
raise Exception(
|
||||||
'That does not even make any sense...')
|
'That does not even make any sense...')
|
||||||
else:
|
else:
|
||||||
raise Exception("This Exception can't happen (I think)")
|
raise Exception("This Exception can't happen")
|
||||||
|
|
||||||
return mean_actions, chol
|
return mean_actions, chol
|
||||||
|
|
||||||
def _parameterize_full(self, latent_dim):
|
@property
|
||||||
# TODO: Implement various techniques for full parameterization (forcing SPD)
|
def _full_params_len(self):
|
||||||
raise Exception(
|
if self.par_type == ParametrizationType.CHOL:
|
||||||
'Programmer-was-to-lazy-to-implement-this-Exception')
|
return self._flat_chol_len
|
||||||
|
elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
|
||||||
|
return self._flat_chol_len
|
||||||
|
raise Exception()
|
||||||
|
|
||||||
def _parameterize_hybrid_from_diag(self, latent_dim):
|
def _parameterize_full(self, params):
|
||||||
|
if self.par_type == ParametrizationType.CHOL:
|
||||||
|
return self._chol_from_flat(params)
|
||||||
|
elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
|
||||||
|
return self._chol_from_flat_sphe_chol(params)
|
||||||
|
raise Exception()
|
||||||
|
|
||||||
|
def _parameterize_hybrid_from_diag(self, params):
|
||||||
# TODO: Implement the hybrid-method for DIAG -> FULL (parameters for pearson-correlation-matrix)
|
# TODO: Implement the hybrid-method for DIAG -> FULL (parameters for pearson-correlation-matrix)
|
||||||
raise Exception(
|
raise Exception(
|
||||||
'Programmer-was-to-lazy-to-implement-this-Exception')
|
'Programmer-was-to-lazy-to-implement-this-Exception')
|
||||||
|
|
||||||
|
def _parameterize_hybrid_from_scalar(self, latent_dim):
    """
    Hybrid parameterization SCALAR => DIAG: combine one scalar produced from the
    latent with a learnable per-dimension vector and force the result positive.

    NOTE(review): this looks unfinished and probably cannot run as written -
    see the notes on the individual lines below; confirm before relying on it.

    :param latent_dim: Dimension of the last layer of the policy
    :return: the positive diagonal chol (vector of length action_dim, presumably)
    """
    # SCALAR => DIAG
    factor = nn.Linear(latent_dim, 1)
    # NOTE(review): nn.Parameter is documented to wrap a Tensor; the int 1 is
    # presumably meant to be a tensor such as th.ones(1) - TODO confirm.
    par = th.ones(self.action_dim) * \
        nn.Parameter(1, requires_grad=True)
    # NOTE(review): factor is a module, not a tensor; indexing it with [0] is
    # presumably meant to select the scalar output of a forward pass - TODO confirm.
    diag_chol = self._ensure_positive_func(par * factor[0])
    return diag_chol
|
||||||
|
|
||||||
|
def _chol_from_flat(self, flat_chol):
    """
    Turn a flat parameter vector into a chol with a positive diagonal.

    :param flat_chol: flat vector of triangular entries (see ``fill_triangular``)
    :return: the expanded triangular matrix after ``_ensure_diagonal_positive``
    """
    triangular = fill_triangular(flat_chol)
    # NOTE(review): the leading dimension is expanded to _flat_chol_len, exactly
    # as before; whether that is the intended batch size should be confirmed.
    expanded = triangular.expand(self._flat_chol_len, -1, -1)
    return self._ensure_diagonal_positive(expanded)
|
||||||
|
|
||||||
|
def _chol_from_flat_sphe_chol(self, flat_sphe_chol):
    """
    Turn a flat vector of spherical-cholesky parameters into a chol.

    All parameters are forced positive first, then arranged as a triangular
    matrix and converted by ``_chol_from_sphe_chol``.

    :param flat_sphe_chol: flat vector of spherical parameters
    :return: the result of ``_chol_from_sphe_chol``
    """
    positive_flat = self._ensure_positive_func(flat_sphe_chol)
    triangular = fill_triangular(positive_flat)
    spherical = triangular.expand(self._flat_chol_len, -1, -1)
    return self._chol_from_sphe_chol(spherical)
|
||||||
|
|
||||||
|
def _chol_from_sphe_chol(self, sphe_chol):
    """
    Convert a spherical-cholesky matrix S into a lower-triangular chol L.

    Row i of S holds a radius in column 0 and angles in columns 1..i
    (0-based). With a_k = tanh(S[i, k]) * pi the entries of L are

        L[i, j] = S[i, 0] * sin(a_1) * ... * sin(a_j) * cos(a_{j+1})   for j < i
        L[i, i] = S[i, 0] * sin(a_1) * ... * sin(a_i)

    so every row of L has norm S[i, 0] and, for S > 0, a positive diagonal.

    Requirements on S (S > 0 is already ensured in _chol_from_flat_sphe_chol):
      - radii S[i, 0] > 0
      - angles in (0, pi); enforced here by applying tanh * pi to the
        (positive) angle parameters.

    :param sphe_chol: tensor of shape (..., n, n) with n == self.action_dim;
        leading batch dimensions are supported. The input is not modified.
    :return: tensor of the same shape holding the lower-triangular chol
    """
    S = sphe_chol
    n = self.action_dim
    L = th.zeros_like(S)
    for i in range(n):
        radius = S[..., i, 0]
        # Running product sin(a_1) * ... * sin(a_j); built out-of-place so
        # that the input tensor is never written to.
        sin_prod = th.ones_like(radius)
        for j in range(i + 1):
            if j == i:
                # Diagonal entry: no cosine factor.
                L[..., i, j] = radius * sin_prod
            else:
                angle = th.tanh(S[..., i, j + 1]) * pi
                L[..., i, j] = radius * sin_prod * th.cos(angle)
                sin_prod = sin_prod * th.sin(angle)
    return L
|
||||||
|
|
||||||
def _ensure_positive_func(self, x):
    """Apply the configured ``enforce_positive_type`` to ``x`` and return the result."""
    return self.enforce_positive_type.apply(x)
|
||||||
|
|
||||||
def _ensure_diagonal_positive(self, chol):
    """
    Force the diagonal of a chol to be positive, leaving all other entries as is.

    :param chol: either a vector (representing a diagonal chol) or a (batch of)
        square matrix
    :return: ``chol`` with ``_ensure_positive_func`` applied to its diagonal
    """
    if len(chol.shape) != 1:
        below = chol.tril(-1)
        above = chol.triu(1)
        raw_diag = chol.diagonal(dim1=-2, dim2=-1)
        positive_diag = self._ensure_positive_func(raw_diag).diag_embed()
        return below + positive_diag + above
    # A vector stands for a diagonal chol: every entry is a diagonal entry.
    return self._ensure_positive_func(chol)
|
||||||
|
|
||||||
def _parameterize_hybrid_from_scalar(self, latent_dim):
|
def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor) -> "UniversalGaussianDistribution":
|
||||||
factor = nn.Linear(latent_dim, 1)
|
|
||||||
par_cov = th.ones(self.action_dim) * \
|
|
||||||
nn.Parameter(1, requires_grad=True)
|
|
||||||
pseudo_cov = par_cov * factor[0]
|
|
||||||
return pseudo_cov
|
|
||||||
|
|
||||||
def proba_distribution(self, mean_actions: th.Tensor, pseudo_cov: th.Tensor) -> "UniversalGaussianDistribution":
|
|
||||||
"""
|
"""
|
||||||
Create the distribution given its parameters (mean, pseudo_cov)
|
Create the distribution given its parameters (mean, chol)
|
||||||
|
|
||||||
:param mean_actions:
|
:param mean_actions:
|
||||||
:param pseudo_cov:
|
:param chol:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
action_std = None
|
if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]:
|
||||||
# TODO: Needs to be expanded
|
self.distribution = Normal(mean_actions, chol)
|
||||||
if self.cov_strength == Strength.DIAG:
|
elif self.cov_strength in [Strength.FULL]:
|
||||||
if self.enforce_positive_type == EnforcePositiveType.LOG:
|
self.distribution = MultivariateNormal(mean_actions, cholesky=chol)
|
||||||
action_std = pseudo_cov.exp()
|
|
||||||
if action_std == None:
|
|
||||||
raise Exception('Not yet implemented!')
|
|
||||||
self.distribution = Normal(mean_actions, action_std)
|
|
||||||
if self.distribution == None:
|
if self.distribution == None:
|
||||||
raise Exception('Not yet implemented!')
|
raise Exception('Unable to create torch distribution')
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||||
|
150
metastable_baselines/misc/tensor_ops.py
Normal file
150
metastable_baselines/misc/tensor_ops.py
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
import torch as th
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def fill_triangular(x, upper=False):
    """
    The following function is derived from TensorFlow Probability
    https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/math/linalg.py#L784
    Copyright (c) 2018 The TensorFlow Probability Authors, licensed under the Apache-2.0 license,
    cf. 3rd-party-licenses.txt file in the root directory of this source tree.

    Creates a (batch of) triangular matrix from a vector of inputs.

    Created matrix can be lower- or upper-triangular. (It is more efficient to
    create the matrix as upper or lower, rather than transpose.)
    Triangular matrix elements are filled in a clockwise spiral. See example,
    below.

    If `x.shape` is `[b1, b2, ..., bB, d]` then the output shape is
    `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
    `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.

    Example:
    ```python
    fill_triangular([1, 2, 3, 4, 5, 6])
    # ==> [[4, 0, 0],
    #      [6, 5, 0],
    #      [3, 2, 1]]
    fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
    # ==> [[1, 2, 3],
    #      [0, 5, 6],
    #      [0, 0, 4]]
    ```

    The key trick is to create an upper triangular matrix by concatenating `x`
    and a tail of itself, then reshaping.
    Suppose that we are filling the upper triangle of an `n`-by-`n` matrix `M`
    from a vector `x`. The matrix `M` contains n**2 entries total. The vector `x`
    contains `n * (n+1) / 2` entries. For concreteness, we'll consider `n = 5`
    (so `x` has `15` entries and `M` has `25`). We'll concatenate `x` and `x` with
    the first (`n = 5`) elements removed and reversed:
    ```python
    x = np.arange(15) + 1
    xc = np.concatenate([x, x[5:][::-1]])
    # ==> array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 14, 13,
    #            12, 11, 10, 9, 8, 7, 6])
    # (We add one to the arange result to disambiguate the zeros below the
    # diagonal of our upper-triangular matrix from the first entry in `x`.)
    # Now, when reshaped, lay this out as a matrix:
    y = np.reshape(xc, [5, 5])
    # ==> array([[ 1,  2,  3,  4,  5],
    #            [ 6,  7,  8,  9, 10],
    #            [11, 12, 13, 14, 15],
    #            [15, 14, 13, 12, 11],
    #            [10,  9,  8,  7,  6]])
    # Finally, zero the elements below the diagonal:
    y = np.triu(y, k=0)
    # ==> array([[ 1,  2,  3,  4,  5],
    #            [ 0,  7,  8,  9, 10],
    #            [ 0,  0, 13, 14, 15],
    #            [ 0,  0,  0, 12, 11],
    #            [ 0,  0,  0,  0,  6]])
    ```
    From this example we see that the resulting matrix is upper-triangular, and
    contains all the entries of x, as desired. The rest is details:
    - If `n` is even, `x` doesn't exactly fill an even number of rows (it fills
      `n / 2` rows and half of an additional row), but the whole scheme still
      works.
    - If we want a lower triangular matrix instead of an upper triangular,
      we remove the first `n` elements from `x` rather than from the reversed
      `x`.
    For additional comparisons, a pure numpy version of this function can be found
    in `distribution_util_test.py`, function `_fill_triangular`.

    Args:
      x: `Tensor` (or anything `th.as_tensor` accepts, e.g. a nested list)
        representing lower (or upper) triangular elements.
      upper: Python `bool` representing whether output matrix should be upper
        triangular (`True`) or lower triangular (`False`, default).
    Returns:
      tril: `Tensor` with lower (or upper) triangular elements filled from `x`.
    Raises:
      ValueError: if `x` cannot be mapped to a triangular matrix.
    """
    # Accept array-likes as the examples above do; tensors pass through unchanged.
    x = th.as_tensor(x)
    if x.dim() == 0:
        raise ValueError('Input must have at least one dimension.')

    m = int(x.shape[-1])
    # Formula derived by solving for n: m = n(n+1)/2. The candidate is rounded
    # and verified with integer arithmetic instead of comparing floats.
    n = int(round(float(np.sqrt(0.25 + 2. * m) - 0.5)))
    if n * (n + 1) // 2 != m:
        raise ValueError('Input right-most shape ({}) does not '
                         'correspond to a triangular matrix.'.format(m))
    new_shape = tuple(x.shape[:-1]) + (n, n)

    if upper:
        x_list = [x, th.flip(x[..., n:], dims=[-1])]
    else:
        x_list = [x[..., n:], th.flip(x, dims=[-1])]

    x = th.cat(x_list, dim=-1).reshape(new_shape)
    x = th.triu(x) if upper else th.tril(x)
    return x
|
||||||
|
|
||||||
|
|
||||||
|
def fill_triangular_inverse(x, upper=False):
    """
    The following function is derived from TensorFlow Probability
    https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/math/linalg.py#L934
    Copyright (c) 2018 The TensorFlow Probability Authors, licensed under the Apache-2.0 license,
    cf. 3rd-party-licenses.txt file in the root directory of this source tree.

    Creates a vector from a (batch of) triangular matrix.

    The vector is created from the lower-triangular or upper-triangular portion
    depending on the value of the parameter `upper`.
    If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
    `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.

    Example:
    ```python
    fill_triangular_inverse(
      [[4, 0, 0],
       [6, 5, 0],
       [3, 2, 1]])
    # ==> [1, 2, 3, 4, 5, 6]
    fill_triangular_inverse(
      [[1, 2, 3],
       [0, 5, 6],
       [0, 0, 4]], upper=True)
    # ==> [1, 2, 3, 4, 5, 6]
    ```

    Args:
      x: `Tensor` (or anything `th.as_tensor` accepts, e.g. a nested list)
        representing lower (or upper) triangular elements.
      upper: Python `bool` representing whether the matrix is upper
        triangular (`True`) or lower triangular (`False`, default).
    Returns:
      flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
        (or upper) triangular elements from `x`.
    """
    # Accept array-likes as the examples above do; tensors pass through unchanged.
    x = th.as_tensor(x)

    n = int(x.shape[-1])
    m = (n * (n + 1)) // 2

    if upper:
        initial_elements = x[..., 0, :]
        triangular_part = x[..., 1:, :]
    else:
        # The last row read backwards provides the first n output elements.
        initial_elements = th.flip(x[..., -1, :], dims=[-1])
        triangular_part = x[..., :-1, :]

    # Rotating the remaining rows by 180 degrees moves their non-zero entries
    # into the zero region, so the sum holds both halves without overlap.
    rotated_triangular_portion = th.flip(triangular_part, dims=[-1, -2])
    consolidated_matrix = triangular_part + rotated_triangular_portion

    end_sequence = consolidated_matrix.reshape(
        tuple(x.shape[:-2]) + (n * (n - 1),))

    y = th.cat([initial_elements, end_sequence[..., :m - n]], dim=-1)
    return y
|
Loading…
Reference in New Issue
Block a user