diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 1fd530e..879bad2 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -350,8 +350,15 @@ class CholNet(nn.Module): chol = th.linalg.cholesky(cov) return chol elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL: - factor = self.factor(x) - return self._parameterize_full(self.params * factor[0]) + # TODO: Maybe possible to improve speed and stability by multiplying with factor in cholesky-form. + factor = self._ensure_positive_func(self.factor(x)) + par_chol = self._parameterize_full(self.params) + cov = (par_chol.T @ par_chol) + if len(factor) > 1: + factor = factor.unsqueeze(2) + cov = cov * factor + chol = th.linalg.cholesky(cov) + return chol raise Exception() @property