diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 924197b..dd8d1aa 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -232,8 +232,11 @@ class UniversalGaussianDistribution(SB3_Distribution): if nobatch: cov = th.mm(cov_sqrt.mT, cov_sqrt) + cov += th.eye(cov.shape[-1])*(self.epsilon) else: cov = th.bmm(cov_sqrt.mT, cov_sqrt) + cov += th.eye(cov.shape[-1]).expand(cov.shape)*(self.epsilon) + chol = th.linalg.cholesky(cov) if vec: diff --git a/metastable_baselines/misc/distTools.py b/metastable_baselines/misc/distTools.py index ef0eee0..da8b6ce 100644 --- a/metastable_baselines/misc/distTools.py +++ b/metastable_baselines/misc/distTools.py @@ -124,6 +124,8 @@ def _sqrt_to_chol(cov_sqrt): cov_sqrt = th.diag_embed(cov_sqrt) cov = th.bmm(cov_sqrt.mT, cov_sqrt) + cov += th.eye(cov.shape[-1]).expand(cov.shape)*(1e-6) + chol = th.linalg.cholesky(cov) if vec: