From 4a24381f46bd358bfa28170edcb75e771819e17f Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 16 Jul 2022 15:17:48 +0200 Subject: [PATCH] Fixed bug when using batches with SPHERICAL_CHOL --- .../distributions/distributions.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index d4742fb..9d63bc8 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -339,8 +339,18 @@ class CholNet(nn.Module): smol = self._parameterize_full(self.params) big = self.padder(smol) pearson_cor_chol = big + th.eye(stds.shape[-1]) - pearson_cor = pearson_cor_chol.T @ pearson_cor_chol - cov = stds.T * pearson_cor * stds + pearson_cor = (pearson_cor_chol.T @ + pearson_cor_chol) + if len(stds.shape) > 1: + # batched operation, we need to expand + pearson_cor = pearson_cor.expand( + (stds.shape[0],)+pearson_cor.shape) + stds = stds.unsqueeze(2) + try: + cov = stds.mT * pearson_cor * stds + except: + import pdb + pdb.set_trace() chol = th.linalg.cholesky(cov) return chol elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL: