Fixed bug when using batches with SPHERICAL_CHOL

This commit is contained in:
Dominik Moritz Roth 2022-07-16 15:17:48 +02:00
parent 04529e8261
commit 4a24381f46

View File

@ -339,8 +339,18 @@ class CholNet(nn.Module):
smol = self._parameterize_full(self.params) smol = self._parameterize_full(self.params)
big = self.padder(smol) big = self.padder(smol)
pearson_cor_chol = big + th.eye(stds.shape[-1]) pearson_cor_chol = big + th.eye(stds.shape[-1])
pearson_cor = pearson_cor_chol.T @ pearson_cor_chol pearson_cor = (pearson_cor_chol.T @
cov = stds.T * pearson_cor * stds pearson_cor_chol)
if len(stds.shape) > 1:
# batched operation, we need to expand
pearson_cor = pearson_cor.expand(
(stds.shape[0],)+pearson_cor.shape)
stds = stds.unsqueeze(2)
try:
cov = stds.mT * pearson_cor * stds
except:
import pdb
pdb.set_trace()
chol = th.linalg.cholesky(cov) chol = th.linalg.cholesky(cov)
return chol return chol
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL: elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL: