From a9e3f295b29c5334f531482ffa9c7f171b872273 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 17 Aug 2022 23:25:24 +0200 Subject: [PATCH] Fixed numerical issues with Wasserstein --- metastable_baselines/distributions/distributions.py | 3 +++ metastable_baselines/misc/distTools.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 924197b..dd8d1aa 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -232,8 +232,11 @@ class UniversalGaussianDistribution(SB3_Distribution): if nobatch: cov = th.mm(cov_sqrt.mT, cov_sqrt) + cov += th.eye(cov.shape[-1])*(self.epsilon) else: cov = th.bmm(cov_sqrt.mT, cov_sqrt) + cov += th.eye(cov.shape[-1]).expand(cov.shape)*(self.epsilon) + chol = th.linalg.cholesky(cov) if vec: diff --git a/metastable_baselines/misc/distTools.py b/metastable_baselines/misc/distTools.py index ef0eee0..da8b6ce 100644 --- a/metastable_baselines/misc/distTools.py +++ b/metastable_baselines/misc/distTools.py @@ -124,6 +124,8 @@ def _sqrt_to_chol(cov_sqrt): cov_sqrt = th.diag_embed(cov_sqrt) cov = th.bmm(cov_sqrt.mT, cov_sqrt) + cov += th.eye(cov.shape[-1]).expand(cov.shape)*(1e-6) + chol = th.linalg.cholesky(cov) if vec: