Fixed gradient throught spherical-chol

This commit is contained in:
Dominik Moritz Roth 2022-08-10 11:55:08 +02:00
parent 520dc98eb5
commit 0e4eedae5e

View File

@ -514,10 +514,12 @@ class CholNet(nn.Module):
n = sphe_chol.shape[-1] n = sphe_chol.shape[-1]
L = th.zeros_like(sphe_chol) L = th.zeros_like(sphe_chol)
for i in range(n): for i in range(n):
t = 1 #t = 1
t = th.Tensor([1])[0]
#s = '' #s = ''
for j in range(i+1): for j in range(i+1):
maybe_cos = 1 #maybe_cos = 1
maybe_cos = th.Tensor([1])[0]
#s_maybe_cos = '' #s_maybe_cos = ''
if i != j and j < n-1 and i < n: if i != j and j < n-1 and i < n:
if batch: if batch: