From 0a037decccce1bf4a105b614c3cc366b8e4143d1 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 3 Sep 2022 11:16:41 +0200 Subject: [PATCH] Implemented cov parametrization via eigen-decomp --- .../distributions/distributions.py | 36 +++++++++- metastable_baselines/misc/givens.py | 66 +++++++++++++++++++ 2 files changed, 99 insertions(+), 3 deletions(-) create mode 100644 metastable_baselines/misc/givens.py diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 5fbac7a..75336f9 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -7,6 +7,8 @@ from torch import nn from torch.distributions import Normal, Independent, MultivariateNormal from math import pi +import givens + from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.distributions import sum_independent_dims @@ -34,9 +36,8 @@ class ParametrizationType(Enum): NONE = 0 CHOL = 1 SPHERICAL_CHOL = 2 - # Not (yet?) implemented: - # GIVENS = 3 - # NNLN_EIGEN = 4 + EIGEN = 3 + EIGEN_RAW = 4 class EnforcePositiveType(Enum): @@ -382,6 +383,9 @@ class CholNet(nn.Module): self._flat_chol_len = action_dim * (action_dim + 1) // 2 + self._givens_rotator = givens.Rotation(action_dim) + self._givens_ident = th.eye(action_dim) + # Yes, this is ugly. # But I don't know how this mess could be elegantly abstracted away... @@ -493,6 +497,10 @@ class CholNet(nn.Module): return self._flat_chol_len elif self.par_type == ParametrizationType.SPHERICAL_CHOL: return self._flat_chol_len + elif self.par_type == ParametrizationType.EIGEN: + return self.action_dim * 2 + elif self.par_type == ParametrizationType.EIGEN_BIJECT: + return self.action_dim * 2 raise Exception() def _parameterize_full(self, params): @@ -500,6 +508,10 @@ class CholNet(nn.Module): return self._chol_from_flat(params) elif self.par_type == ParametrizationType.SPHERICAL_CHOL: return self._chol_from_flat_sphe_chol(params) + elif self.par_type == ParametrizationType.EIGEN: + return self._chol_from_givens_params(params, True) + elif self.par_type == ParametrizationType.EIGEN_RAW: + return self._chol_from_givens_params(params, False) raise Exception() def _chol_from_flat(self, flat_chol): @@ -576,6 +588,24 @@ class CholNet(nn.Module): return chol.tril(-1) + self._ensure_positive_func(chol.diagonal(dim1=-2, dim2=-1)).diag_embed() + chol.triu(1) + def _chol_from_givens_params(self, params, bijection=False): + theta, eigenv = params[:self.action_dim], params[self.action_dim:] + + eigenv = self._ensure_positive_func(eigenv) + + if bijection: + eigenv = th.cumsum(eigenv, -1) + # reverse order, oh well... + + self._givens_rot.theta = theta + Q = self._givens_rotator(self._givens_ident) + Qinv = Q.transpose(dim0=-2, dim1=-1) + + cov = Q * th.diag(eigenv) * Qinv + chol = th.linalg.cholesky(cov) + + return chol + def string(self): return '' diff --git a/metastable_baselines/misc/givens.py b/metastable_baselines/misc/givens.py new file mode 100644 index 0000000..9ca3ee1 --- /dev/null +++ b/metastable_baselines/misc/givens.py @@ -0,0 +1,66 @@ +# Source : https://github.com/diadochos/givens-torch +# TODO: License + +import itertools +import torch +import torch.nn as nn + + +def G_transpose(D, i, j, theta): + """Generate Givens rotation matrix. + >>> G_transpose(2, 0, 1, torch.FloatTensor([[3.1415 / 2]])) + tensor([[ 4.6329e-05, 1.0000e+00], + [-1.0000e+00, 4.6329e-05]]) + """ + R = torch.eye(D) + s, c = torch.sin(theta), torch.cos(theta) + R[i, i] = c + R[j, j] = c + R[i, j] = s + R[j, i] = -s + return R + + +class Rotation(nn.Module): + def __init__(self, D): + """ + >>> # Initialized as an identity. + >>> A, R = torch.eye(3), Rotation(3) + >>> torch.all(A.eq(R(A))).item() + True + """ + super().__init__() + self.D = D + self.theta = torch.zeros( + (len(list(itertools.combinations(range(self.D), 2))), )) + + def forward(self, x): + """Apply rotation. + >>> A, R = torch.eye(3), Rotation(3) + >>> R.theta = torch.FloatTensor([3.1415 / 2, 0., 0.]) + >>> R(A) + tensor([[ 4.6329e-05, 1.0000e+00, 0.0000e+00], + [-1.0000e+00, 4.6329e-05, 0.0000e+00], + [ 0.0000e+00, 0.0000e+00, 1.0000e+00]]) + """ + for idx, (i, j) in enumerate(itertools.combinations(range(self.D), 2)): + x = torch.matmul(x, G_transpose(self.D, i, j, self.theta[idx])) + return x + + def reverse(self, x): + """Apply reverse rotation. + >>> A, R = torch.eye(3), Rotation(3) + >>> R.weight = torch.FloatTensor([1., 2., 3.]) + >>> torch.any(A.eq(R(A))).item() + True + >>> torch.all(A.eq(R.reverse(R(A)))).item() + True + """ + for idx, (i, j) in reversed(list(enumerate(itertools.combinations(range(self.D), 2)))): + x = torch.matmul(x, G_transpose(self.D, i, j, -self.theta[idx])) + return x + + +if __name__ == '__main__': + import doctest + doctest.testmod()