Fixed bugs with givens-rotator for eigen
This commit is contained in:
parent
e31c25135e
commit
82a174122a
@ -557,11 +557,7 @@ class CholNet(nn.Module):
|
|||||||
maybe_cos = th.cos(th.tanh(S[i, j+1])*pi)
|
maybe_cos = th.cos(th.tanh(S[i, j+1])*pi)
|
||||||
#s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')'
|
#s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')'
|
||||||
if batch:
|
if batch:
|
||||||
# try:
|
|
||||||
L[:, i, j] = (S[:, i, 0] * t.T) * maybe_cos.T
|
L[:, i, j] = (S[:, i, 0] * t.T) * maybe_cos.T
|
||||||
# except:
|
|
||||||
# import pdb
|
|
||||||
# pdb.set_trace()
|
|
||||||
else:
|
else:
|
||||||
L[i, j] = S[i, 0] * t * maybe_cos
|
L[i, j] = S[i, 0] * t * maybe_cos
|
||||||
# print('[L_'+str(i+1)+']_'+str(j+1) +
|
# print('[L_'+str(i+1)+']_'+str(j+1) +
|
||||||
@ -586,7 +582,8 @@ class CholNet(nn.Module):
|
|||||||
dim2=-1)).diag_embed() + chol.triu(1)
|
dim2=-1)).diag_embed() + chol.triu(1)
|
||||||
|
|
||||||
def _chol_from_givens_params(self, params, bijection=False):
|
def _chol_from_givens_params(self, params, bijection=False):
|
||||||
theta, eigenv = params[:self.action_dim], params[self.action_dim:]
|
theta, eigenv = params[:,
|
||||||
|
:self.action_dim], params[:, self.action_dim:]
|
||||||
|
|
||||||
eigenv = self._ensure_positive_func(eigenv)
|
eigenv = self._ensure_positive_func(eigenv)
|
||||||
|
|
||||||
@ -594,8 +591,12 @@ class CholNet(nn.Module):
|
|||||||
eigenv = th.cumsum(eigenv, -1)
|
eigenv = th.cumsum(eigenv, -1)
|
||||||
# reverse order, oh well...
|
# reverse order, oh well...
|
||||||
|
|
||||||
self._givens_rotator.theta = theta
|
Q = th.zeros((theta.shape[0], self.action_dim,
|
||||||
Q = self._givens_rotator(self._givens_ident)
|
self.action_dim), device=eigenv.device)
|
||||||
|
for b in range(theta.shape[0]):
|
||||||
|
self._givens_rotator.theta = theta[b]
|
||||||
|
Q[b] = self._givens_rotator(self._givens_ident)
|
||||||
|
|
||||||
Qinv = Q.transpose(dim0=-2, dim1=-1)
|
Qinv = Q.transpose(dim0=-2, dim1=-1)
|
||||||
|
|
||||||
cov = Q * th.diag(eigenv) * Qinv
|
cov = Q * th.diag(eigenv) * Qinv
|
||||||
|
@ -59,8 +59,3 @@ class Rotation(nn.Module):
|
|||||||
for idx, (i, j) in reversed(list(enumerate(itertools.combinations(range(self.D), 2)))):
|
for idx, (i, j) in reversed(list(enumerate(itertools.combinations(range(self.D), 2)))):
|
||||||
x = torch.matmul(x, G_transpose(self.D, i, j, -self.theta[idx]))
|
x = torch.matmul(x, G_transpose(self.D, i, j, -self.theta[idx]))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
import doctest
|
|
||||||
doctest.testmod()
|
|
||||||
|
Loading…
Reference in New Issue
Block a user