diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index d4742fb..9d63bc8 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -339,8 +339,18 @@ class CholNet(nn.Module): smol = self._parameterize_full(self.params) big = self.padder(smol) pearson_cor_chol = big + th.eye(stds.shape[-1]) - pearson_cor = pearson_cor_chol.T @ pearson_cor_chol - cov = stds.T * pearson_cor * stds + pearson_cor = (pearson_cor_chol.T @ + pearson_cor_chol) + if len(stds.shape) > 1: + # batched operation, we need to expand + pearson_cor = pearson_cor.expand( + (stds.shape[0],)+pearson_cor.shape) + stds = stds.unsqueeze(2) + try: + cov = stds.mT * pearson_cor * stds + except: + import pdb + pdb.set_trace() chol = th.linalg.cholesky(cov) return chol elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL: