Fixed bug when using batches with SPHERICAL_CHOL
This commit is contained in:
parent
04529e8261
commit
4a24381f46
@ -339,8 +339,18 @@ class CholNet(nn.Module):
|
|||||||
smol = self._parameterize_full(self.params)
|
smol = self._parameterize_full(self.params)
|
||||||
big = self.padder(smol)
|
big = self.padder(smol)
|
||||||
pearson_cor_chol = big + th.eye(stds.shape[-1])
|
pearson_cor_chol = big + th.eye(stds.shape[-1])
|
||||||
pearson_cor = pearson_cor_chol.T @ pearson_cor_chol
|
pearson_cor = (pearson_cor_chol.T @
|
||||||
cov = stds.T * pearson_cor * stds
|
pearson_cor_chol)
|
||||||
|
if len(stds.shape) > 1:
|
||||||
|
# batched operation, we need to expand
|
||||||
|
pearson_cor = pearson_cor.expand(
|
||||||
|
(stds.shape[0],)+pearson_cor.shape)
|
||||||
|
stds = stds.unsqueeze(2)
|
||||||
|
try:
|
||||||
|
cov = stds.mT * pearson_cor * stds
|
||||||
|
except:
|
||||||
|
import pdb
|
||||||
|
pdb.set_trace()
|
||||||
chol = th.linalg.cholesky(cov)
|
chol = th.linalg.cholesky(cov)
|
||||||
return chol
|
return chol
|
||||||
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
|
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
|
||||||
|
Loading…
Reference in New Issue
Block a user