Fixed Spherical_Chol not accepting batches

This commit is contained in:
Dominik Moritz Roth 2022-08-17 22:55:42 +02:00
parent 86e6bfb65b
commit 9fffe048af

View File

@ -518,7 +518,6 @@ class CholNet(nn.Module):
return chol return chol
def _chol_from_sphe_chol(self, sphe_chol): def _chol_from_sphe_chol(self, sphe_chol):
# TODO: Test with batched data
# TODO: Make efficient more # TODO: Make efficient more
# Note: # Note:
# We must should ensure: # We must should ensure:
@ -527,16 +526,21 @@ class CholNet(nn.Module):
# We already ensure S > 0 in _chol_from_flat_sphe_chol # We already ensure S > 0 in _chol_from_flat_sphe_chol
# We ensure < pi by applying tanh*pi to all applicable elements # We ensure < pi by applying tanh*pi to all applicable elements
batch = (len(sphe_chol.shape) == 3) batch = (len(sphe_chol.shape) == 3)
batch_size = sphe_chol.shape[0]
S = sphe_chol S = sphe_chol
n = sphe_chol.shape[-1] n = sphe_chol.shape[-1]
L = th.zeros_like(sphe_chol) L = th.zeros_like(sphe_chol)
for i in range(n): for i in range(n):
#t = 1 #t = 1
t = th.Tensor([1])[0] t = th.Tensor([1])[0]
if batch:
t = t.expand((batch_size, 1))
#s = '' #s = ''
for j in range(i+1): for j in range(i+1):
#maybe_cos = 1 #maybe_cos = 1
maybe_cos = th.Tensor([1])[0] maybe_cos = th.Tensor([1])[0]
if batch:
maybe_cos = maybe_cos.expand((batch_size, 1))
#s_maybe_cos = '' #s_maybe_cos = ''
if i != j and j < n-1 and i < n: if i != j and j < n-1 and i < n:
if batch: if batch:
@ -545,14 +549,19 @@ class CholNet(nn.Module):
maybe_cos = th.cos(th.tanh(S[i, j+1])*pi) maybe_cos = th.cos(th.tanh(S[i, j+1])*pi)
#s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')' #s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')'
if batch: if batch:
L[:, i, j] = S[:, i, 0] * t * maybe_cos # try:
L[:, i, j] = (S[:, i, 0] * t.T) * maybe_cos.T
# except:
# import pdb
# pdb.set_trace()
else: else:
L[i, j] = S[i, 0] * t * maybe_cos L[i, j] = S[i, 0] * t * maybe_cos
# print('[L_'+str(i+1)+']_'+str(j+1) + # print('[L_'+str(i+1)+']_'+str(j+1) +
# '=[l_'+str(i+1)+']_1'+s+s_maybe_cos) # '=[l_'+str(i+1)+']_1'+s+s_maybe_cos)
if j <= i and j < n-1 and i < n: if j <= i and j < n-1 and i < n:
if batch: if batch:
t *= th.sin(th.tanh(S[:, i, j+1])*pi) tc = t.clone()
t = (tc.T * th.sin(th.tanh(S[:, i, j+1])*pi)).T
else: else:
t *= th.sin(th.tanh(S[i, j+1])*pi) t *= th.sin(th.tanh(S[i, j+1])*pi)
#s += 'sin([l_'+str(i+1)+']_'+str(j+2)+')' #s += 'sin([l_'+str(i+1)+']_'+str(j+2)+')'