Fixed numerical issues with Wasserstein

This commit is contained in:
Dominik Moritz Roth 2022-08-17 23:25:24 +02:00
parent 9fffe048af
commit a9e3f295b2
2 changed files with 5 additions and 0 deletions

View File

@ -232,8 +232,11 @@ class UniversalGaussianDistribution(SB3_Distribution):
if nobatch: if nobatch:
cov = th.mm(cov_sqrt.mT, cov_sqrt) cov = th.mm(cov_sqrt.mT, cov_sqrt)
cov += th.eye(cov.shape[-1])*(self.epsilon)
else: else:
cov = th.bmm(cov_sqrt.mT, cov_sqrt) cov = th.bmm(cov_sqrt.mT, cov_sqrt)
cov += th.eye(cov.shape[-1]).expand(cov.shape)*(self.epsilon)
chol = th.linalg.cholesky(cov) chol = th.linalg.cholesky(cov)
if vec: if vec:

View File

@ -124,6 +124,8 @@ def _sqrt_to_chol(cov_sqrt):
cov_sqrt = th.diag_embed(cov_sqrt) cov_sqrt = th.diag_embed(cov_sqrt)
cov = th.bmm(cov_sqrt.mT, cov_sqrt) cov = th.bmm(cov_sqrt.mT, cov_sqrt)
cov += th.eye(cov.shape[-1]).expand(cov.shape)*(1e-6)
chol = th.linalg.cholesky(cov) chol = th.linalg.cholesky(cov)
if vec: if vec: