Fixed bugs with givens-rotator for eigen
This commit is contained in:
parent
e31c25135e
commit
82a174122a
@ -557,11 +557,7 @@ class CholNet(nn.Module):
|
||||
maybe_cos = th.cos(th.tanh(S[i, j+1])*pi)
|
||||
#s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')'
|
||||
if batch:
|
||||
# try:
|
||||
L[:, i, j] = (S[:, i, 0] * t.T) * maybe_cos.T
|
||||
# except:
|
||||
# import pdb
|
||||
# pdb.set_trace()
|
||||
else:
|
||||
L[i, j] = S[i, 0] * t * maybe_cos
|
||||
# print('[L_'+str(i+1)+']_'+str(j+1) +
|
||||
@ -586,7 +582,8 @@ class CholNet(nn.Module):
|
||||
dim2=-1)).diag_embed() + chol.triu(1)
|
||||
|
||||
def _chol_from_givens_params(self, params, bijection=False):
|
||||
theta, eigenv = params[:self.action_dim], params[self.action_dim:]
|
||||
theta, eigenv = params[:,
|
||||
:self.action_dim], params[:, self.action_dim:]
|
||||
|
||||
eigenv = self._ensure_positive_func(eigenv)
|
||||
|
||||
@ -594,8 +591,12 @@ class CholNet(nn.Module):
|
||||
eigenv = th.cumsum(eigenv, -1)
|
||||
# reverse order, oh well...
|
||||
|
||||
self._givens_rotator.theta = theta
|
||||
Q = self._givens_rotator(self._givens_ident)
|
||||
Q = th.zeros((theta.shape[0], self.action_dim,
|
||||
self.action_dim), device=eigenv.device)
|
||||
for b in range(theta.shape[0]):
|
||||
self._givens_rotator.theta = theta[b]
|
||||
Q[b] = self._givens_rotator(self._givens_ident)
|
||||
|
||||
Qinv = Q.transpose(dim0=-2, dim1=-1)
|
||||
|
||||
cov = Q * th.diag(eigenv) * Qinv
|
||||
|
@ -59,8 +59,3 @@ class Rotation(nn.Module):
|
||||
for idx, (i, j) in reversed(list(enumerate(itertools.combinations(range(self.D), 2)))):
|
||||
x = torch.matmul(x, G_transpose(self.D, i, j, -self.theta[idx]))
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import doctest
|
||||
doctest.testmod()
|
||||
|
Loading…
Reference in New Issue
Block a user