From 82a174122aba92e8b046616d7e752e678af102b7 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 24 Oct 2022 10:08:31 +0200 Subject: [PATCH] Fixed bugs with givens-rotator for eigen --- .../distributions/distributions.py | 15 ++++++++------- metastable_baselines/misc/givens.py | 5 ----- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 7667cab..6c40e00 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -557,11 +557,7 @@ class CholNet(nn.Module): maybe_cos = th.cos(th.tanh(S[i, j+1])*pi) #s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')' if batch: - # try: L[:, i, j] = (S[:, i, 0] * t.T) * maybe_cos.T - # except: - # import pdb - # pdb.set_trace() else: L[i, j] = S[i, 0] * t * maybe_cos # print('[L_'+str(i+1)+']_'+str(j+1) + @@ -586,7 +582,8 @@ class CholNet(nn.Module): dim2=-1)).diag_embed() + chol.triu(1) def _chol_from_givens_params(self, params, bijection=False): - theta, eigenv = params[:self.action_dim], params[self.action_dim:] + theta, eigenv = params[:, + :self.action_dim], params[:, self.action_dim:] eigenv = self._ensure_positive_func(eigenv) @@ -594,8 +591,12 @@ class CholNet(nn.Module): eigenv = th.cumsum(eigenv, -1) # reverse order, oh well... - self._givens_rotator.theta = theta - Q = self._givens_rotator(self._givens_ident) + Q = th.zeros((theta.shape[0], self.action_dim, + self.action_dim), device=eigenv.device) + for b in range(theta.shape[0]): + self._givens_rotator.theta = theta[b] + Q[b] = self._givens_rotator(self._givens_ident) + Qinv = Q.transpose(dim0=-2, dim1=-1) cov = Q * th.diag(eigenv) * Qinv diff --git a/metastable_baselines/misc/givens.py b/metastable_baselines/misc/givens.py index 9ca3ee1..b392597 100644 --- a/metastable_baselines/misc/givens.py +++ b/metastable_baselines/misc/givens.py @@ -59,8 +59,3 @@ class Rotation(nn.Module): for idx, (i, j) in reversed(list(enumerate(itertools.combinations(range(self.D), 2)))): x = torch.matmul(x, G_transpose(self.D, i, j, -self.theta[idx])) return x - - -if __name__ == '__main__': - import doctest - doctest.testmod()