Port FullCov into PCA
commit 220328f4b9 (parent 4485e558a8)
@@ -46,6 +46,9 @@ class White_Noise():
         shape = self.known_shape
         return th.Tensor(np.random.normal(0, 1, shape))
 
+    def reset(self):
+        pass
+
 
 def get_colored_noise(beta, known_shape=None):
     if beta == 0:
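Side note (not part of the commit): `White_Noise.reset()` can be a no-op because every draw is i.i.d., so there is no internal state to clear; stateful colored-noise processes are the reason the interface has a reset hook at all. A minimal usage sketch, assuming the sampling method shown above is `__call__` and that `White_Noise` takes the same `known_shape` argument that `get_colored_noise` exposes:

```python
from priorConditionedAnnealing import noise

noise_gen = noise.White_Noise(known_shape=(4, 2))  # hypothetical: 4 envs, 2 action dims
eps = noise_gen()   # fresh N(0, 1) draw of shape (4, 2) on every call
noise_gen.reset()   # no-op: white noise keeps no state between draws
```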
@@ -5,19 +5,20 @@ import scipy.spatial
 from torch import nn
 from stable_baselines3.common.distributions import Distribution as SB3_Distribution
 from stable_baselines3.common.distributions import sum_independent_dims
-from torch.distributions import Normal
+from torch.distributions import Normal, MultivariateNormal
 import torch.nn.functional as F
 
 from priorConditionedAnnealing import noise, kernel
+from priorConditionedAnnealing.tensor_ops import fill_triangular, fill_triangular_inverse
 
 
 class Par_Strength(Enum):
     SCALAR = 'SCALAR'
     DIAG = 'DIAG'
+    FULL = 'FULL'
     CONT_SCALAR = 'CONT_SCALAR'
     CONT_DIAG = 'CONT_DIAG'
     CONT_HYBRID = 'CONT_HYBRID'
+    CONT_FULL = 'CONT_FULL'
 
 class EnforcePositiveType(Enum):
     # This need to be implemented in this ugly fashion,
@@ -31,7 +32,7 @@ class EnforcePositiveType(Enum):
 
     def apply(self, x):
         # aaaaaa
-        return [nn.Identity(), nn.Softplus(beta=1, threshold=20), th.abs, nn.ReLU(inplace=False), th.log][self.value](x)
+        return [nn.Identity(), nn.Softplus(beta=10, threshold=2), th.abs, nn.ReLU(inplace=False), th.log][self.value](x)
 
 
 class Avaible_Kernel_Funcs(Enum):
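A note on what this parameter change does (my reading, not stated in the commit): PyTorch's `Softplus` computes softplus(x) = (1/β)·log(1 + exp(βx)) and reverts to the identity once β·x exceeds `threshold`. Raising β from 1 to 10 makes the positivity mapping hug ReLU much more tightly around zero, and `threshold=2` means the linear fallback already applies for x > 0.2:

```python
import torch as th
from torch import nn

sp_old = nn.Softplus(beta=1, threshold=20)   # the PyTorch defaults
sp_new = nn.Softplus(beta=10, threshold=2)   # much sharper knee around zero

x = th.tensor([-1.0, 0.0, 0.2, 1.0])
print(sp_old(x))  # ~ [0.3133, 0.6931, 0.7981, 1.3133]
print(sp_new(x))  # ~ [4.5e-06, 0.0693, 0.2127, 1.0000], close to relu(x)
```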
@@ -98,6 +99,7 @@ class PCA_Distribution(SB3_Distribution):
         epsilon: float = 1e-6,
         skip_conditioning: bool = False,
         temporal_gradient_emission: bool = False,
+        msqrt_induces_full: bool = False,
         Base_Noise=noise.White_Noise,
     ):
         super().__init__()
@@ -111,9 +113,12 @@ class PCA_Distribution(SB3_Distribution):
         self.epsilon = epsilon
         self.skip_conditioning = skip_conditioning
         self.temporal_gradient_emission = temporal_gradient_emission
+        self.msqrt_induces_full = msqrt_induces_full
 
         self.base_noise = cast_to_Noise(Base_Noise, (n_envs, action_dim))
 
+        assert not (not skip_conditioning and self.is_full()), 'Conditioning full Covariances not yet implemented'
+
         # Premature optimization is the root of all evil
         self._build_conditioner()
         # *Optimizes it anyways*
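Readability note, not part of the commit: the new assertion is a double negative; by De Morgan's laws it is equivalent to requiring that conditioning is skipped whenever a full covariance is used. A hypothetical standalone form of the same check:

```python
def check_supported(skip_conditioning: bool, is_full: bool) -> None:
    # not (not A and B)  <=>  A or not B
    assert skip_conditioning or not is_full, \
        'Conditioning full Covariances not yet implemented'

check_supported(skip_conditioning=True, is_full=True)    # ok: conditioning is skipped
check_supported(skip_conditioning=False, is_full=False)  # ok: diagonal/scalar case
```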
@@ -126,8 +131,13 @@ class PCA_Distribution(SB3_Distribution):
 
     def proba_distribution(
             self, mean_actions: th.Tensor, std_actions: th.Tensor) -> SB3_Distribution:
-        self.distribution = Normal(
-            mean_actions, std_actions)
+        if self.is_full():
+            self.distribution = MultivariateNormal(mean_actions, scale_tril=std_actions, validate_args=not self.msqrt_induces_full)
+            #self.distribution.scale = th.diagonal(std_actions, dim1=-2, dim2=-1)
+            self.distribution._mark_mSqrt = self.msqrt_induces_full
+        else:
+            self.distribution = Normal(
+                mean_actions, std_actions)
         return self
 
     def log_prob(self, actions: th.Tensor) -> th.Tensor:
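What the new branch buys (a sketch, not from the commit): `MultivariateNormal` with `scale_tril` models the action dimensions jointly with covariance L·Lᵀ, whereas `Normal` treats each dimension independently (SB3 then sums the per-dimension log-probs via `sum_independent_dims`). Disabling `validate_args` when `msqrt_induces_full` is set is presumably necessary because a general matrix square root need not be lower-triangular, which the argument validator would reject:

```python
import torch as th
from torch.distributions import Normal, MultivariateNormal

mean = th.zeros(2)
L = th.tensor([[0.5, 0.0],
               [0.3, 0.4]])                     # lower-triangular factor

full = MultivariateNormal(mean, scale_tril=L)   # covariance = L @ L.T
diag = Normal(mean, th.tensor([0.5, 0.4]))      # independent per-dim scales

x = th.tensor([0.1, -0.2])
print(full.log_prob(x))   # scalar: joint density over both dimensions
print(diag.log_prob(x))   # shape (2,): one log-prob per dimension
```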
@@ -156,24 +166,29 @@ class PCA_Distribution(SB3_Distribution):
             return self.mode()
         return self.sample(traj=trajectory)
 
-    def sample(self, traj: th.Tensor, f_sigma: int = 1, epsilon=None) -> th.Tensor:
+    def sample(self, traj: th.Tensor, f_sigma: float = 1.0, epsilon=None) -> th.Tensor:
         assert self.skip_conditioning or type(traj) != type(None), 'A past trajectory has to be supplied if conditinoning is performed'
-        pi_mean, pi_std = self.distribution.mean.cpu(), self.distribution.scale.cpu()
-        rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)
+        pi_mean, pi_decomp = self.distribution.mean.cpu(), self.distribution.scale_tril.cpu() if self.is_full() else self.distribution.scale.cpu()
+        rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_decomp)
         rho_std = rho_std * f_sigma
-        eta = self._get_rigged(pi_mean, pi_std,
+        eta = self._get_rigged(pi_mean, pi_decomp,
                                rho_mean, rho_std,
                                epsilon)
         # reparameterization with rigged samples
-        actions = pi_mean + pi_std * eta
+        if self.is_full():
+            actions = pi_mean + th.matmul(pi_decomp, eta.unsqueeze(-1)).squeeze(-1)
+        else:
+            actions = pi_mean + pi_decomp * eta
+
         self.gaussian_actions = actions
         return actions
 
     def is_contextual(self):
-        return True # TODO: Remove, when bug for non-contextual is fixed
-        # Always returning True will merely waste cpu cycles
-        return self.par_strength not in [Par_Strength.SCALAR, Par_Strength.DIAG]
+        return self.par_strength in [Par_Strength.CONT_SCALAR, Par_Strength.CONT_DIAG, Par_Strength.CONT_HYBRID, Par_Strength.CONT_FULL]
+
+    def is_full(self):
+        return self.par_strength in [Par_Strength.FULL, Par_Strength.CONT_FULL]
 
 
     def _get_rigged(self, pi_mean, pi_std, rho_mean, rho_std, epsilon=None):
         # Ugly function to ensure, that the gradients flow as intended for each modus operandi
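The full-covariance branch is the standard reparameterization x = μ + Lη: if η ~ N(0, I), then Cov(Lη) = L·Cov(η)·Lᵀ = L·Lᵀ. A toy sketch of the shapes involved (assuming a batch of means and a shared factor; the real `pi_decomp` may carry batch dimensions):

```python
import torch as th

th.manual_seed(0)
n_envs, action_dim = 4, 3
pi_mean = th.zeros(n_envs, action_dim)
L = th.tril(th.rand(action_dim, action_dim)) + th.eye(action_dim)  # toy factor
eta = th.randn(n_envs, action_dim)

diag_actions = pi_mean + L.diagonal() * eta                           # old path
full_actions = pi_mean + th.matmul(L, eta.unsqueeze(-1)).squeeze(-1)  # new path
print(full_actions.shape)  # torch.Size([4, 3])
```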
@@ -251,7 +266,8 @@ class PCA_Distribution(SB3_Distribution):
         # S_{ij} = \frac{1}{D_j} \left( A_{ij} - \sum_{k=1}^{j-1} S_{ik} S_{jk} D_k \right), \qquad\text{for } i>j
         # https://martin-thoma.com/images/2012/07/cholesky-zerlegung-numerik.png
         # This way conditioning of the GP can be done in O(dim(A)) time.
-        if not self.is_contextual():
+        if not self.is_contextual() and False:
+            # Always assuming contextual will merely waste cpu cycles
             # TODO: fix, this does not work
             # safe inplace
             self.conditioner[-1, -
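For reference (standard textbook material, not the repo's exact variant, which keeps an explicit diagonal D): the commented recurrence is the column-by-column Cholesky/LDL update. In the plain L·Lᵀ special case each entry depends only on previously computed columns, which is what makes extending an existing factor by one row cheap:

```python
import numpy as np

def cholesky_crout(A):
    # Column-by-column Cholesky recurrence; appending one row/column to A
    # only requires computing one new row of L.
    n = A.shape[0]
    L = np.zeros_like(A, dtype=float)
    for j in range(n):
        L[j, j] = np.sqrt(A[j, j] - np.sum(L[j, :j] ** 2))
        for i in range(j + 1, n):
            L[i, j] = (A[i, j] - np.sum(L[i, :j] * L[j, :j])) / L[j, j]
    return L

A = np.array([[4.0, 2.0], [2.0, 3.0]])
assert np.allclose(cholesky_crout(A) @ cholesky_crout(A).T, A)
```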
@@ -289,7 +305,7 @@ class PCA_Distribution(SB3_Distribution):
 
 
 class StdNet(nn.Module):
-    def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: bool, epsilon: float, return_log_std):
+    def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: bool, epsilon: float, return_log_std: bool):
         super().__init__()
         self.action_dim = action_dim
         self.latent_dim = latent_dim
@@ -299,8 +315,6 @@ class StdNet(nn.Module):
 
         self.epsilon = epsilon
         self.return_log_std = return_log_std
-        if return_log_std:
-            self.enforce_positive_type = EnforcePositiveType.NONE
 
         if self.par_strength == Par_Strength.SCALAR:
             self.param = nn.Parameter(
@@ -308,6 +322,11 @@ class StdNet(nn.Module):
         elif self.par_strength == Par_Strength.DIAG:
             self.param = nn.Parameter(
                 th.Tensor(th.ones(action_dim)*std_init), requires_grad=True)
+        elif self.par_strength == Par_Strength.FULL:
+            ident = th.eye(action_dim)*std_init
+            ident_chol = fill_triangular_inverse(ident)
+            self.param = nn.Parameter(
+                th.Tensor(ident_chol), requires_grad=True)
         elif self.par_strength == Par_Strength.CONT_SCALAR:
             self.net = nn.Linear(latent_dim, 1)
         elif self.par_strength == Par_Strength.CONT_HYBRID:
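Shape note (my framing, consistent with the new tensor_ops.py below): a lower-triangular action_dim × action_dim factor has action_dim·(action_dim+1)/2 free entries, which is exactly what the FULL branch stores as a flat parameter vector and what the CONT_FULL head outputs:

```python
import torch as th
from priorConditionedAnnealing.tensor_ops import fill_triangular_inverse

action_dim, std_init = 3, 0.5
n_flat = action_dim * (action_dim + 1) // 2       # 6 free parameters for a 3x3 factor

ident_chol = fill_triangular_inverse(th.eye(action_dim) * std_init)
print(ident_chol.shape)                           # torch.Size([6])
```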
@@ -316,6 +335,11 @@ class StdNet(nn.Module):
                 th.Tensor(th.ones(action_dim)*std_init), requires_grad=True)
         elif self.par_strength == Par_Strength.CONT_DIAG:
             self.net = nn.Linear(latent_dim, self.action_dim)
+            self.bias = th.ones(action_dim)*self.std_init
+        elif self.par_strength == Par_Strength.CONT_FULL:
+            self.net = nn.Linear(latent_dim, action_dim * (action_dim + 1) // 2)
+            self.bias = fill_triangular_inverse(th.eye(action_dim)*self.std_init)
+
 
     def forward(self, x: th.Tensor) -> th.Tensor:
         if self.par_strength == Par_Strength.SCALAR:
@@ -323,6 +347,8 @@ class StdNet(nn.Module):
                 th.ones(self.action_dim) * self.param[0])
         elif self.par_strength == Par_Strength.DIAG:
             return self._ensure_positive_func(self.param)
+        elif self.par_strength == Par_Strength.FULL:
+            return self._chol_from_flat(self.param)
         elif self.par_strength == Par_Strength.CONT_SCALAR:
             cont = self.net(x)
             diag_chol = th.ones(self.action_dim, device=cont.device) * cont * self.std_init
@@ -332,14 +358,27 @@ class StdNet(nn.Module):
             return self._ensure_positive_func(self.param * cont)
         elif self.par_strength == Par_Strength.CONT_DIAG:
             cont = self.net(x)
-            diag_chol = cont * self.std_init
+            diag_chol = cont + self.bias
             return self._ensure_positive_func(diag_chol)
+        elif self.par_strength == Par_Strength.CONT_FULL:
+            cont = self.net(x)
+            return self._chol_from_flat(cont + self.bias)
 
         raise Exception()
 
     def _ensure_positive_func(self, x):
         return self.enforce_positive_type.apply(x) + self.epsilon
 
+    def _chol_from_flat(self, flat_chol):
+        chol = fill_triangular(flat_chol)
+        return self._ensure_diagonal_positive(chol)
+
+    def _ensure_diagonal_positive(self, chol):
+        if len(chol.shape) == 1:
+            # If our chol is a vector (representing a diagonal chol)
+            return self._ensure_positive_func(chol)
+        return chol.tril(-1) + self._ensure_positive_func(chol.diagonal(dim1=-2,
+                                                                        dim2=-1)).diag_embed() + chol.triu(1)
     def string(self):
         return '<StdNet />'
 
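How `_ensure_diagonal_positive` works (a sketch; softplus stands in here for whatever `EnforcePositiveType` the instance actually uses): a square matrix splits exactly into strictly-lower, diagonal, and strictly-upper parts, so mapping only the diagonal through a positive function yields a valid `scale_tril` factor without touching the off-diagonal entries:

```python
import torch as th
import torch.nn.functional as F

chol = th.tensor([[ 0.7,  0.0, 0.0],
                  [ 0.2, -0.3, 0.0],
                  [-0.1,  0.4, 0.0]])

# Exact split used above: strictly-lower + diagonal + strictly-upper.
rebuilt = chol.tril(-1) + chol.diagonal(dim1=-2, dim2=-1).diag_embed() + chol.triu(1)
assert th.equal(rebuilt, chol)

# Only the diagonal goes through the positivity map (plus epsilon):
diag = F.softplus(chol.diagonal(dim1=-2, dim2=-1)) + 1e-6
valid = chol.tril(-1) + diag.diag_embed() + chol.triu(1)
print(valid.diagonal())  # strictly positive, as a Cholesky factor requires
```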
priorConditionedAnnealing/tensor_ops.py (new file, 150 lines)
@@ -0,0 +1,150 @@
+import torch as th
+import numpy as np
+
+
+def fill_triangular(x, upper=False):
+    """
+    The following function is derived from TensorFlow Probability
+    https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/math/linalg.py#L784
+    Copyright (c) 2018 The TensorFlow Probability Authors, licensed under the Apache-2.0 license,
+    cf. 3rd-party-licenses.txt file in the root directory of this source tree.
+    Creates a (batch of) triangular matrix from a vector of inputs.
+    Created matrix can be lower- or upper-triangular. (It is more efficient to
+    create the matrix as upper or lower, rather than transpose.)
+    Triangular matrix elements are filled in a clockwise spiral. See example,
+    below.
+    If `x.shape` is `[b1, b2, ..., bB, d]` then the output shape is
+    `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
+    `n = int(np.sqrt(0.25 + 2. * m) - 0.5)`.
+    Example:
+    ```python
+    fill_triangular([1, 2, 3, 4, 5, 6])
+    # ==> [[4, 0, 0],
+    #      [6, 5, 0],
+    #      [3, 2, 1]]
+    fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
+    # ==> [[1, 2, 3],
+    #      [0, 5, 6],
+    #      [0, 0, 4]]
+    ```
+    The key trick is to create an upper triangular matrix by concatenating `x`
+    and a tail of itself, then reshaping.
+    Suppose that we are filling the upper triangle of an `n`-by-`n` matrix `M`
+    from a vector `x`. The matrix `M` contains n**2 entries total. The vector `x`
+    contains `n * (n+1) / 2` entries. For concreteness, we'll consider `n = 5`
+    (so `x` has `15` entries and `M` has `25`). We'll concatenate `x` and `x` with
+    the first (`n = 5`) elements removed and reversed:
+    ```python
+    x = np.arange(15) + 1
+    xc = np.concatenate([x, x[5:][::-1]])
+    # ==> array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 14, 13,
+    #            12, 11, 10, 9, 8, 7, 6])
+    # (We add one to the arange result to disambiguate the zeros below the
+    # diagonal of our upper-triangular matrix from the first entry in `x`.)
+    # Now, when reshaped, lay this out as a matrix:
+    y = np.reshape(xc, [5, 5])
+    # ==> array([[ 1,  2,  3,  4,  5],
+    #            [ 6,  7,  8,  9, 10],
+    #            [11, 12, 13, 14, 15],
+    #            [15, 14, 13, 12, 11],
+    #            [10,  9,  8,  7,  6]])
+    # Finally, zero the elements below the diagonal:
+    y = np.triu(y, k=0)
+    # ==> array([[ 1,  2,  3,  4,  5],
+    #            [ 0,  7,  8,  9, 10],
+    #            [ 0,  0, 13, 14, 15],
+    #            [ 0,  0,  0, 12, 11],
+    #            [ 0,  0,  0,  0,  6]])
+    ```
+    From this example we see that the resulting matrix is upper-triangular, and
+    contains all the entries of x, as desired. The rest is details:
+    - If `n` is even, `x` doesn't exactly fill an even number of rows (it fills
+      `n / 2` rows and half of an additional row), but the whole scheme still
+      works.
+    - If we want a lower triangular matrix instead of an upper triangular,
+      we remove the first `n` elements from `x` rather than from the reversed
+      `x`.
+    For additional comparisons, a pure numpy version of this function can be found
+    in `distribution_util_test.py`, function `_fill_triangular`.
+    Args:
+        x: `Tensor` representing lower (or upper) triangular elements.
+        upper: Python `bool` representing whether output matrix should be upper
+            triangular (`True`) or lower triangular (`False`, default).
+    Returns:
+        tril: `Tensor` with lower (or upper) triangular elements filled from `x`.
+    Raises:
+        ValueError: if `x` cannot be mapped to a triangular matrix.
+    """
+
+    m = np.int32(x.shape[-1])
+    # Formula derived by solving for n: m = n(n+1)/2.
+    n = np.sqrt(0.25 + 2. * m) - 0.5
+    if n != np.floor(n):
+        raise ValueError('Input right-most shape ({}) does not '
+                         'correspond to a triangular matrix.'.format(m))
+    n = np.int32(n)
+    new_shape = x.shape[:-1] + (n, n)
+
+    ndims = len(x.shape)
+    if upper:
+        x_list = [x, th.flip(x[..., n:], dims=[ndims - 1])]
+    else:
+        x_list = [x[..., n:], th.flip(x, dims=[ndims - 1])]
+
+    x = th.cat(x_list, dim=-1).reshape(new_shape)
+    x = th.triu(x) if upper else th.tril(x)
+    return x
+
+
+def fill_triangular_inverse(x, upper=False):
+    """
+    The following function is derived from TensorFlow Probability
+    https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/math/linalg.py#L934
+    Copyright (c) 2018 The TensorFlow Probability Authors, licensed under the Apache-2.0 license,
+    cf. 3rd-party-licenses.txt file in the root directory of this source tree.
+    Creates a vector from a (batch of) triangular matrix.
+    The vector is created from the lower-triangular or upper-triangular portion
+    depending on the value of the parameter `upper`.
+    If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
+    `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.
+    Example:
+    ```python
+    fill_triangular_inverse(
+        [[4, 0, 0],
+         [6, 5, 0],
+         [3, 2, 1]])
+    # ==> [1, 2, 3, 4, 5, 6]
+    fill_triangular_inverse(
+        [[1, 2, 3],
+         [0, 5, 6],
+         [0, 0, 4]], upper=True)
+    # ==> [1, 2, 3, 4, 5, 6]
+    ```
+    Args:
+        x: `Tensor` representing lower (or upper) triangular elements.
+        upper: Python `bool` representing whether output matrix should be upper
+            triangular (`True`) or lower triangular (`False`, default).
+    Returns:
+        flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
+            (or upper) triangular elements from `x`.
+    """
+
+    n = np.int32(x.shape[-1])
+    m = np.int32((n * (n + 1)) // 2)
+
+    ndims = len(x.shape)
+    if upper:
+        initial_elements = x[..., 0, :]
+        triangular_part = x[..., 1:, :]
+    else:
+        initial_elements = th.flip(x[..., -1, :], dims=[ndims - 2])
+        triangular_part = x[..., :-1, :]
+
+    rotated_triangular_portion = th.flip(
+        th.flip(triangular_part, dims=[ndims - 1]), dims=[ndims - 2])
+    consolidated_matrix = triangular_part + rotated_triangular_portion
+
+    end_sequence = consolidated_matrix.reshape(x.shape[:-2] + (n * (n - 1),))
+
+    y = th.cat([initial_elements, end_sequence[..., :m - n]], dim=-1)
+    return y
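A quick round-trip check of the two ported helpers, matching the docstring examples above:

```python
import torch as th
from priorConditionedAnnealing.tensor_ops import fill_triangular, fill_triangular_inverse

vec = th.arange(1.0, 7.0)           # 6 = 3*(3+1)/2 entries
L = fill_triangular(vec)            # 3x3 lower-triangular, spiral-filled
print(L)
# tensor([[4., 0., 0.],
#         [6., 5., 0.],
#         [3., 2., 1.]])
assert th.equal(fill_triangular_inverse(L), vec)
```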