Fixed: _chol_from_sphe_chol was unable to handle batches
This commit is contained in:
parent
c141599662
commit
046fa78206
@ -388,6 +388,7 @@ class CholNet(nn.Module):
|
||||
# S[i,j] e (0, pi) where i = 2..n, j = 2..i
|
||||
# We already ensure S > 0 in _chol_from_flat_sphe_chol
|
||||
# We ensure < pi by applying tanh*pi to all applicable elements
|
||||
batch = (len(sphe_chol.shape) == 3)
|
||||
S = sphe_chol
|
||||
n = sphe_chol.shape[-1]
|
||||
L = th.zeros_like(sphe_chol)
|
||||
@ -397,13 +398,22 @@ class CholNet(nn.Module):
|
||||
for j in range(i+1):
|
||||
maybe_cos = 1
|
||||
#s_maybe_cos = ''
|
||||
if i != j:
|
||||
if i != j and j < n-1 and i < n:
|
||||
if batch:
|
||||
maybe_cos = th.cos(th.tanh(S[:, i, j+1])*pi)
|
||||
else:
|
||||
maybe_cos = th.cos(th.tanh(S[i, j+1])*pi)
|
||||
s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')'
|
||||
#s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')'
|
||||
if batch:
|
||||
L[:, i, j] = S[:, i, 0] * t * maybe_cos
|
||||
else:
|
||||
L[i, j] = S[i, 0] * t * maybe_cos
|
||||
# print('[L_'+str(i+1)+']_'+str(j+1) +
|
||||
# '=[l_'+str(i+1)+']_1'+s+s_maybe_cos)
|
||||
if j <= i and j < n-1 and i < n:
|
||||
if batch:
|
||||
t *= th.sin(th.tanh(S[:, i, j+1])*pi)
|
||||
else:
|
||||
t *= th.sin(th.tanh(S[i, j+1])*pi)
|
||||
#s += 'sin([l_'+str(i+1)+']_'+str(j+2)+')'
|
||||
return L
|
||||
|
Loading…
Reference in New Issue
Block a user