Fixed Spherical_Chol not accepting batches
This commit is contained in:
parent
86e6bfb65b
commit
9fffe048af
@ -518,7 +518,6 @@ class CholNet(nn.Module):
|
|||||||
return chol
|
return chol
|
||||||
|
|
||||||
def _chol_from_sphe_chol(self, sphe_chol):
|
def _chol_from_sphe_chol(self, sphe_chol):
|
||||||
# TODO: Test with batched data
|
|
||||||
# TODO: Make efficient more
|
# TODO: Make efficient more
|
||||||
# Note:
|
# Note:
|
||||||
# We must should ensure:
|
# We must should ensure:
|
||||||
@ -527,16 +526,21 @@ class CholNet(nn.Module):
|
|||||||
# We already ensure S > 0 in _chol_from_flat_sphe_chol
|
# We already ensure S > 0 in _chol_from_flat_sphe_chol
|
||||||
# We ensure < pi by applying tanh*pi to all applicable elements
|
# We ensure < pi by applying tanh*pi to all applicable elements
|
||||||
batch = (len(sphe_chol.shape) == 3)
|
batch = (len(sphe_chol.shape) == 3)
|
||||||
|
batch_size = sphe_chol.shape[0]
|
||||||
S = sphe_chol
|
S = sphe_chol
|
||||||
n = sphe_chol.shape[-1]
|
n = sphe_chol.shape[-1]
|
||||||
L = th.zeros_like(sphe_chol)
|
L = th.zeros_like(sphe_chol)
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
#t = 1
|
#t = 1
|
||||||
t = th.Tensor([1])[0]
|
t = th.Tensor([1])[0]
|
||||||
|
if batch:
|
||||||
|
t = t.expand((batch_size, 1))
|
||||||
#s = ''
|
#s = ''
|
||||||
for j in range(i+1):
|
for j in range(i+1):
|
||||||
#maybe_cos = 1
|
#maybe_cos = 1
|
||||||
maybe_cos = th.Tensor([1])[0]
|
maybe_cos = th.Tensor([1])[0]
|
||||||
|
if batch:
|
||||||
|
maybe_cos = maybe_cos.expand((batch_size, 1))
|
||||||
#s_maybe_cos = ''
|
#s_maybe_cos = ''
|
||||||
if i != j and j < n-1 and i < n:
|
if i != j and j < n-1 and i < n:
|
||||||
if batch:
|
if batch:
|
||||||
@ -545,14 +549,19 @@ class CholNet(nn.Module):
|
|||||||
maybe_cos = th.cos(th.tanh(S[i, j+1])*pi)
|
maybe_cos = th.cos(th.tanh(S[i, j+1])*pi)
|
||||||
#s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')'
|
#s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')'
|
||||||
if batch:
|
if batch:
|
||||||
L[:, i, j] = S[:, i, 0] * t * maybe_cos
|
# try:
|
||||||
|
L[:, i, j] = (S[:, i, 0] * t.T) * maybe_cos.T
|
||||||
|
# except:
|
||||||
|
# import pdb
|
||||||
|
# pdb.set_trace()
|
||||||
else:
|
else:
|
||||||
L[i, j] = S[i, 0] * t * maybe_cos
|
L[i, j] = S[i, 0] * t * maybe_cos
|
||||||
# print('[L_'+str(i+1)+']_'+str(j+1) +
|
# print('[L_'+str(i+1)+']_'+str(j+1) +
|
||||||
# '=[l_'+str(i+1)+']_1'+s+s_maybe_cos)
|
# '=[l_'+str(i+1)+']_1'+s+s_maybe_cos)
|
||||||
if j <= i and j < n-1 and i < n:
|
if j <= i and j < n-1 and i < n:
|
||||||
if batch:
|
if batch:
|
||||||
t *= th.sin(th.tanh(S[:, i, j+1])*pi)
|
tc = t.clone()
|
||||||
|
t = (tc.T * th.sin(th.tanh(S[:, i, j+1])*pi)).T
|
||||||
else:
|
else:
|
||||||
t *= th.sin(th.tanh(S[i, j+1])*pi)
|
t *= th.sin(th.tanh(S[i, j+1])*pi)
|
||||||
#s += 'sin([l_'+str(i+1)+']_'+str(j+2)+')'
|
#s += 'sin([l_'+str(i+1)+']_'+str(j+2)+')'
|
||||||
|
Loading…
Reference in New Issue
Block a user