Fixed: Wrong simplification for Hybrid[SCALAR=>FULL]

This commit is contained in:
Dominik Moritz Roth 2022-07-17 00:47:47 +02:00
parent 046fa78206
commit 49f9acff3e

View File

@ -350,8 +350,15 @@ class CholNet(nn.Module):
chol = th.linalg.cholesky(cov)
return chol
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
factor = self.factor(x)
return self._parameterize_full(self.params * factor[0])
# TODO: Maybe possible to improve speed and stability by multiplying with factor in cholesky-form.
factor = self._ensure_positive_func(self.factor(x))
par_chol = self._parameterize_full(self.params)
cov = (par_chol.T @ par_chol)
if len(factor) > 1:
factor = factor.unsqueeze(2)
cov = cov * factor
chol = th.linalg.cholesky(cov)
return chol
raise Exception()
@property