From 41e4170b2f894f168ef50cb2b9beb69bb742d464 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 11 Jul 2022 17:28:08 +0200 Subject: [PATCH] Fixes + spherical_chol --- 3rd-party-licenses.txt | 2 + .../distributions/distributions.py | 169 +++++++++++------- metastable_baselines/misc/tensor_ops.py | 150 ++++++++++++++++ 3 files changed, 261 insertions(+), 60 deletions(-) create mode 100644 metastable_baselines/misc/tensor_ops.py diff --git a/3rd-party-licenses.txt b/3rd-party-licenses.txt index 565e9c6..5ae7b0a 100644 --- a/3rd-party-licenses.txt +++ b/3rd-party-licenses.txt @@ -10,6 +10,8 @@ such Third Party IP, are set forth below. Overview -------------------------------------------------------------------------- +# TODO: Tensorflow-Probability + # TODO: TrustRegionLayers boschresearch/trust-region-layers diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 1494424..1efd902 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -4,6 +4,7 @@ from enum import Enum import torch as th from torch import nn from torch.distributions import Normal, MultivariateNormal +from math import pi from stable_baselines3.common.preprocessing import get_action_dim @@ -13,6 +14,7 @@ from stable_baselines3.common.distributions import DiagGaussianDistribution from ..misc.fakeModule import FakeModule from ..misc.distTools import new_dist_like +from ..misc.tensor_ops import fill_triangular # TODO: Integrate and Test what I currently have before adding more complexity # TODO: Support Squashed Dists (tanh) @@ -30,9 +32,8 @@ class Strength(Enum): class ParametrizationType(Enum): - # Currently only Chol is implemented CHOL = 1 - #SPHERICAL_CHOL = 2 + SPHERICAL_CHOL = 2 #GIVENS = 3 @@ -77,6 +78,9 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng if ps == Strength.SCALAR and cs == Strength.FULL: # TODO: Maybe allow? 
continue + if ps == Strength.DIAG and cs == Strength.FULL: + # TODO: Implement + continue if ps == Strength.NONE: yield (ps, cs, None, None) else: @@ -95,70 +99,72 @@ class UniversalGaussianDistribution(SB3_Distribution): :param action_dim: Dimension of the action space. """ - def __init__(self, action_dim: int): + def __init__(self, action_dim: int, neural_strength=Strength.DIAG, cov_strength=Strength.DIAG, parameterization_type=ParametrizationType.CHOL, enforce_positive_type=EnforcePositiveType.ABS, prob_squashing_type=ProbSquashingType.TANH): super(UniversalGaussianDistribution, self).__init__() - self.par_strength = Strength.DIAG - self.cov_strength = Strength.DIAG - self.par_type = ParametrizationType.CHOL - self.enforce_positive_type = EnforcePositiveType.LOG - self.prob_squashing_type = ProbSquashingType.TANH + self.par_strength = neural_strength + self.cov_strength = cov_strength + self.par_type = parameterization_type + self.enforce_positive_type = enforce_positive_type + self.prob_squashing_type = prob_squashing_type self.distribution = None + self._flat_chol_len = action_dim * (action_dim + 1) // 2 + def new_dist_like_me(self, mean, pseudo_chol): p = self.distribution np = new_dist_like(p, mean, pseudo_chol) - new = UniversalGaussianDistribution(self.action_dim) - new.par_strength = self.par_strength - new.cov_strength = self.cov_strength - new.par_type = self.par_type - new.enforce_positive_type = self.enforce_positive_type - new.prob_squashing_type = self.prob_squashing_type + new = UniversalGaussianDistribution(self.action_dim, neural_strength=self.par_strength, cov_strength=self.cov_strength, + parameterization_type=self.par_type, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type) new.distribution = np return new - def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]: + def proba_distribution_net(self, latent_dim: int, std_init: float = 0.0) -> 
Tuple[nn.Module, nn.Module]: """ Create the layers and parameter that represent the distribution: one output will be the mean of the Gaussian, the other parameter will be the - standard deviation (log std in fact to allow negative values) + standard deviation :param latent_dim: Dimension of the last layer of the policy (before the action layer) - :param log_std_init: Initial value for the log standard deviation - :return: We return two nn.Modules (mean, chol). + :param std_init: Initial value for the standard deviation + :return: We return two nn.Modules (mean, chol). chol can be a vector if the full chol would be a diagonal. """ + assert std_init >= 0.0, "std can not be initialized to a negative value." + # TODO: Allow chol to be vector when only diagonal. mean_actions = nn.Linear(latent_dim, self.action_dim) if self.par_strength == Strength.NONE: if self.cov_strength == Strength.NONE: - pseudo_cov_par = th.ones(self.action_dim) * log_std_init + pseudo_cov_par = th.ones(self.action_dim) * std_init elif self.cov_strength == Strength.SCALAR: pseudo_cov_par = th.ones(self.action_dim) * \ - nn.Parameter(log_std_init, requires_grad=True) + nn.Parameter(std_init, requires_grad=True) + pseudo_cov_par = self._ensure_positive_func(pseudo_cov_par) elif self.cov_strength == Strength.DIAG: pseudo_cov_par = nn.Parameter( - th.ones(self.action_dim) * log_std_init, requires_grad=True) + th.ones(self.action_dim) * std_init, requires_grad=True) + pseudo_cov_par = self._ensure_positive_func(pseudo_cov_par) elif self.cov_strength == Strength.FULL: - # TODO: This won't work, need to ensure SPD! - # TODO: Off-axis init? - pseudo_cov_par = nn.Parameter( - th.diag_embed(th.ones(self.action_dim) * log_std_init), requires_grad=True) + # TODO: Init Off-axis differently? 
+ param = nn.Parameter( + th.ones(self._full_params_len) * std_init, requires_grad=True) + pseudo_cov_par = self._parameterize_full(param) chol = FakeModule(pseudo_cov_par) elif self.par_strength == self.cov_strength: - if self.par_strength == Strength.NONE: - chol = FakeModule(th.ones(self.action_dim)) - elif self.par_strength == Strength.SCALAR: - # TODO: Does it work like this? Test! + if self.par_strength == Strength.SCALAR: std = nn.Linear(latent_dim, 1) - chol = th.ones(self.action_dim) * std + diag_chol = th.ones(self.action_dim) * std + chol = self._ensure_positive_func(diag_chol) elif self.par_strength == Strength.DIAG: - chol = nn.Linear(latent_dim, self.action_dim) + diag_chol = nn.Linear(latent_dim, self.action_dim) + chol = self._ensure_positive_func(diag_chol) elif self.par_strength == Strength.FULL: - chol = self._parameterize_full(latent_dim) + params = nn.Linear(latent_dim, self._full_params_len) + chol = self._parameterize_full(params) elif self.par_strength > self.cov_strength: raise Exception( 'The parameterization can not be stronger than the actual covariance.') @@ -171,52 +177,95 @@ class UniversalGaussianDistribution(SB3_Distribution): raise Exception( 'That does not even make any sense...') else: - raise Exception("This Exception can't happen (I think)") + raise Exception("This Exception can't happen") return mean_actions, chol - def _parameterize_full(self, latent_dim): - # TODO: Implement various techniques for full parameterization (forcing SPD) - raise Exception( - 'Programmer-was-to-lazy-to-implement-this-Exception') + @property + def _full_params_len(self): + if self.par_type == ParametrizationType.CHOL: + return self._flat_chol_len + elif self.par_type == ParametrizationType.SPHERICAL_CHOL: + return self._flat_chol_len + raise Exception() - def _parameterize_hybrid_from_diag(self, latent_dim): + def _parameterize_full(self, params): + if self.par_type == ParametrizationType.CHOL: + return self._chol_from_flat(params) + elif 
self.par_type == ParametrizationType.SPHERICAL_CHOL: + return self._chol_from_flat_sphe_chol(params) + raise Exception() + + def _parameterize_hybrid_from_diag(self, params): # TODO: Implement the hybrid-method for DIAG -> FULL (parameters for pearson-correlation-matrix) raise Exception( 'Programmer-was-to-lazy-to-implement-this-Exception') + def _parameterize_hybrid_from_scalar(self, latent_dim): + # SCALAR => DIAG + factor = nn.Linear(latent_dim, 1) + par = th.ones(self.action_dim) * \ + nn.Parameter(1, requires_grad=True) + diag_chol = self._ensure_positive_func(par * factor[0]) + return diag_chol + + def _chol_from_flat(self, flat_chol): + chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1) + return self._ensure_diagonal_positive(chol) + + def _chol_from_flat_sphe_chol(self, flat_sphe_chol): + pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol) + sphe_chol = fill_triangular(pos_flat_sphe_chol).expand( + self._flat_chol_len, -1, -1) + chol = self._chol_from_sphe_chol(sphe_chol) + return chol + + def _chol_from_sphe_chol(self, sphe_chol): + # TODO: Test with batched data + # TODO: Make efficient + # Note: + # We must should ensure: + # S[i,1] > 0 where i = 1..n + # S[i,j] e (0, pi) where i = 2..n, j = 2..i + # We already ensure S > 0 in _chol_from_flat_sphe_chol + # We ensure < pi by applying tanh*pi to all applicable elements + S = sphe_chol + n = self.action_dim + L = th.zeros_like(sphe_chol) + for i in range(n): + for j in range(i): + t = S[i, 1] + for k in range(1, j+1): + t *= th.sin(th.tanh(S[i, k])*pi) + if i != j: + t *= th.cos(th.tanh(S[i, j+1])*pi) + L[i, j] = t + return L + def _ensure_positive_func(self, x): return self.enforce_positive_type.apply(x) - def _ensure_diagonal_positive(self, pseudo_chol): - pseudo_chol.tril(-1) + self._ensure_positive_func(pseudo_chol.diagonal(dim1=-2, - dim2=-1)).diag_embed() + pseudo_chol.triu(1) + def _ensure_diagonal_positive(self, chol): + if len(chol.shape) == 1: + # If our chol is a 
vector (representing a diagonal chol) + return self._ensure_positive_func(chol) + return chol.tril(-1) + self._ensure_positive_func(chol.diagonal(dim1=-2, + dim2=-1)).diag_embed() + chol.triu(1) - def _parameterize_hybrid_from_scalar(self, latent_dim): - factor = nn.Linear(latent_dim, 1) - par_cov = th.ones(self.action_dim) * \ - nn.Parameter(1, requires_grad=True) - pseudo_cov = par_cov * factor[0] - return pseudo_cov - - def proba_distribution(self, mean_actions: th.Tensor, pseudo_cov: th.Tensor) -> "UniversalGaussianDistribution": + def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor) -> "UniversalGaussianDistribution": """ - Create the distribution given its parameters (mean, pseudo_cov) + Create the distribution given its parameters (mean, chol) :param mean_actions: - :param pseudo_cov: + :param chol: :return: """ - action_std = None - # TODO: Needs to be expanded - if self.cov_strength == Strength.DIAG: - if self.enforce_positive_type == EnforcePositiveType.LOG: - action_std = pseudo_cov.exp() - if action_std == None: - raise Exception('Not yet implemented!') - self.distribution = Normal(mean_actions, action_std) + if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]: + self.distribution = Normal(mean_actions, chol) + elif self.cov_strength in [Strength.FULL]: + self.distribution = MultivariateNormal(mean_actions, cholesky=chol) if self.distribution == None: - raise Exception('Not yet implemented!') + raise Exception('Unable to create torch distribution') return self def log_prob(self, actions: th.Tensor) -> th.Tensor: diff --git a/metastable_baselines/misc/tensor_ops.py b/metastable_baselines/misc/tensor_ops.py new file mode 100644 index 0000000..ee72dcc --- /dev/null +++ b/metastable_baselines/misc/tensor_ops.py @@ -0,0 +1,150 @@ +import torch as th +import numpy as np + + +def fill_triangular(x, upper=False): + """ + The following function is derived from TensorFlow Probability + 
https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/math/linalg.py#L784 + Copyright (c) 2018 The TensorFlow Probability Authors, licensed under the Apache-2.0 license, + cf. 3rd-party-licenses.txt file in the root directory of this source tree. + Creates a (batch of) triangular matrix from a vector of inputs. + Created matrix can be lower- or upper-triangular. (It is more efficient to + create the matrix as upper or lower, rather than transpose.) + Triangular matrix elements are filled in a clockwise spiral. See example, + below. + If `x.shape` is `[b1, b2, ..., bB, d]` then the output shape is + `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e., + `n = int(np.sqrt(0.25 + 2. * m) - 0.5)`. + Example: + ```python + fill_triangular([1, 2, 3, 4, 5, 6]) + # ==> [[4, 0, 0], + # [6, 5, 0], + # [3, 2, 1]] + fill_triangular([1, 2, 3, 4, 5, 6], upper=True) + # ==> [[1, 2, 3], + # [0, 5, 6], + # [0, 0, 4]] + ``` + The key trick is to create an upper triangular matrix by concatenating `x` + and a tail of itself, then reshaping. + Suppose that we are filling the upper triangle of an `n`-by-`n` matrix `M` + from a vector `x`. The matrix `M` contains n**2 entries total. The vector `x` + contains `n * (n+1) / 2` entries. For concreteness, we'll consider `n = 5` + (so `x` has `15` entries and `M` has `25`). We'll concatenate `x` and `x` with + the first (`n = 5`) elements removed and reversed: + ```python + x = np.arange(15) + 1 + xc = np.concatenate([x, x[5:][::-1]]) + # ==> array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 14, 13, + # 12, 11, 10, 9, 8, 7, 6]) + # (We add one to the arange result to disambiguate the zeros below the + # diagonal of our upper-triangular matrix from the first entry in `x`.) 
+ # Now, when reshaped, we lay this out as a matrix: + y = np.reshape(xc, [5, 5]) + # ==> array([[ 1, 2, 3, 4, 5], + # [ 6, 7, 8, 9, 10], + # [11, 12, 13, 14, 15], + # [15, 14, 13, 12, 11], + # [10, 9, 8, 7, 6]]) + # Finally, zero the elements below the diagonal: + y = np.triu(y, k=0) + # ==> array([[ 1, 2, 3, 4, 5], + # [ 0, 7, 8, 9, 10], + # [ 0, 0, 13, 14, 15], + # [ 0, 0, 0, 12, 11], + # [ 0, 0, 0, 0, 6]]) + ``` + From this example we see that the resulting matrix is upper-triangular, and + contains all the entries of x, as desired. The rest is details: + - If `n` is even, `x` doesn't exactly fill an even number of rows (it fills + `n / 2` rows and half of an additional row), but the whole scheme still + works. + - If we want a lower triangular matrix instead of an upper triangular, + we remove the first `n` elements from `x` rather than from the reversed + `x`. + For additional comparisons, a pure numpy version of this function can be found + in `distribution_util_test.py`, function `_fill_triangular`. + Args: + x: `Tensor` representing lower (or upper) triangular elements. + upper: Python `bool` representing whether output matrix should be upper + triangular (`True`) or lower triangular (`False`, default). + Returns: + tril: `Tensor` with lower (or upper) triangular elements filled from `x`. + Raises: + ValueError: if `x` cannot be mapped to a triangular matrix. + """ + + m = np.int32(x.shape[-1]) + # Formula derived by solving for n: m = n(n+1)/2. + n = np.sqrt(0.25 + 2. 
* m) - 0.5 + if n != np.floor(n): + raise ValueError('Input right-most shape ({}) does not ' + 'correspond to a triangular matrix.'.format(m)) + n = np.int32(n) + new_shape = x.shape[:-1] + (n, n) + + ndims = len(x.shape) + if upper: + x_list = [x, th.flip(x[..., n:], dims=[ndims - 1])] + else: + x_list = [x[..., n:], th.flip(x, dims=[ndims - 1])] + + x = th.cat(x_list, dim=-1).reshape(new_shape) + x = th.triu(x) if upper else th.tril(x) + return x + + +def fill_triangular_inverse(x, upper=False): + """ + The following function is derived from TensorFlow Probability + https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/math/linalg.py#L934 + Copyright (c) 2018 The TensorFlow Probability Authors, licensed under the Apache-2.0 license, + cf. 3rd-party-licenses.txt file in the root directory of this source tree. + Creates a vector from a (batch of) triangular matrix. + The vector is created from the lower-triangular or upper-triangular portion + depending on the value of the parameter `upper`. + If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is + `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`. + Example: + ```python + fill_triangular_inverse( + [[4, 0, 0], + [6, 5, 0], + [3, 2, 1]]) + # ==> [1, 2, 3, 4, 5, 6] + fill_triangular_inverse( + [[1, 2, 3], + [0, 5, 6], + [0, 0, 4]], upper=True) + # ==> [1, 2, 3, 4, 5, 6] + ``` + Args: + x: `Tensor` representing lower (or upper) triangular elements. + upper: Python `bool` representing whether output matrix should be upper + triangular (`True`) or lower triangular (`False`, default). + Returns: + flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower + (or upper) triangular elements from `x`. 
+ """ + + n = np.int32(x.shape[-1]) + m = np.int32((n * (n + 1)) // 2) + + ndims = len(x.shape) + if upper: + initial_elements = x[..., 0, :] + triangular_part = x[..., 1:, :] + else: + initial_elements = ch.flip(x[..., -1, :], dims=[ndims - 2]) + triangular_part = x[..., :-1, :] + + rotated_triangular_portion = ch.flip( + th.flip(triangular_part, dims=[ndims - 1]), dims=[ndims - 2]) + consolidated_matrix = triangular_part + rotated_triangular_portion + + end_sequence = consolidated_matrix.reshape(x.shape[:-2] + (n * (n - 1),)) + + y = th.cat([initial_elements, end_sequence[..., :m - n]], dim=-1) + return y