Fixed numerical issues with Wasserstein
This commit is contained in:
parent
9fffe048af
commit
a9e3f295b2
@ -232,8 +232,11 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
|
||||
if nobatch:
|
||||
cov = th.mm(cov_sqrt.mT, cov_sqrt)
|
||||
cov += th.eye(cov.shape[-1])*(self.epsilon)
|
||||
else:
|
||||
cov = th.bmm(cov_sqrt.mT, cov_sqrt)
|
||||
cov += th.eye(cov.shape[-1]).expand(cov.shape)*(self.epsilon)
|
||||
|
||||
chol = th.linalg.cholesky(cov)
|
||||
|
||||
if vec:
|
||||
|
@ -124,6 +124,8 @@ def _sqrt_to_chol(cov_sqrt):
|
||||
cov_sqrt = th.diag_embed(cov_sqrt)
|
||||
|
||||
cov = th.bmm(cov_sqrt.mT, cov_sqrt)
|
||||
cov += th.eye(cov.shape[-1]).expand(cov.shape)*(1e-6)
|
||||
|
||||
chol = th.linalg.cholesky(cov)
|
||||
|
||||
if vec:
|
||||
|
Loading…
Reference in New Issue
Block a user