Fixed bugs with givens-rotator for eigen

This commit is contained in:
Dominik Moritz Roth 2022-10-24 10:08:31 +02:00
parent e31c25135e
commit 82a174122a
2 changed files with 8 additions and 12 deletions

View File

@ -557,11 +557,7 @@ class CholNet(nn.Module):
maybe_cos = th.cos(th.tanh(S[i, j+1])*pi) maybe_cos = th.cos(th.tanh(S[i, j+1])*pi)
#s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')' #s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')'
if batch: if batch:
# try:
L[:, i, j] = (S[:, i, 0] * t.T) * maybe_cos.T L[:, i, j] = (S[:, i, 0] * t.T) * maybe_cos.T
# except:
# import pdb
# pdb.set_trace()
else: else:
L[i, j] = S[i, 0] * t * maybe_cos L[i, j] = S[i, 0] * t * maybe_cos
# print('[L_'+str(i+1)+']_'+str(j+1) + # print('[L_'+str(i+1)+']_'+str(j+1) +
@ -586,7 +582,8 @@ class CholNet(nn.Module):
dim2=-1)).diag_embed() + chol.triu(1) dim2=-1)).diag_embed() + chol.triu(1)
def _chol_from_givens_params(self, params, bijection=False): def _chol_from_givens_params(self, params, bijection=False):
theta, eigenv = params[:self.action_dim], params[self.action_dim:] theta, eigenv = params[:,
:self.action_dim], params[:, self.action_dim:]
eigenv = self._ensure_positive_func(eigenv) eigenv = self._ensure_positive_func(eigenv)
@ -594,8 +591,12 @@ class CholNet(nn.Module):
eigenv = th.cumsum(eigenv, -1) eigenv = th.cumsum(eigenv, -1)
# reverse order, oh well... # reverse order, oh well...
self._givens_rotator.theta = theta Q = th.zeros((theta.shape[0], self.action_dim,
Q = self._givens_rotator(self._givens_ident) self.action_dim), device=eigenv.device)
for b in range(theta.shape[0]):
self._givens_rotator.theta = theta[b]
Q[b] = self._givens_rotator(self._givens_ident)
Qinv = Q.transpose(dim0=-2, dim1=-1) Qinv = Q.transpose(dim0=-2, dim1=-1)
cov = Q * th.diag(eigenv) * Qinv cov = Q * th.diag(eigenv) * Qinv

View File

@ -59,8 +59,3 @@ class Rotation(nn.Module):
for idx, (i, j) in reversed(list(enumerate(itertools.combinations(range(self.D), 2)))): for idx, (i, j) in reversed(list(enumerate(itertools.combinations(range(self.D), 2)))):
x = torch.matmul(x, G_transpose(self.D, i, j, -self.theta[idx])) x = torch.matmul(x, G_transpose(self.D, i, j, -self.theta[idx]))
return x return x
if __name__ == '__main__':
import doctest
doctest.testmod()