Fixed bug with Wasserstein

This commit is contained in:
Dominik Moritz Roth 2022-09-03 11:59:52 +02:00
parent c7ca326345
commit 05b048fd0e

View File

@ -111,7 +111,7 @@ def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt: th.Tensor): def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt: th.Tensor):
chol = _sqrt_to_chol(cov_sqrt) chol = _sqrt_to_chol(cov_sqrt, only_diag=has_diag_cov(orig_p))
new = new_dist_like(orig_p, mean, chol) new = new_dist_like(orig_p, mean, chol)
@ -122,20 +122,13 @@ def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt:
return new return new
def _sqrt_to_chol(cov_sqrt): def _sqrt_to_chol(cov_sqrt, only_diag=False):
vec = False
if len(cov_sqrt.shape) == 2:
vec = True
if vec:
cov_sqrt = th.diag_embed(cov_sqrt)
cov = th.bmm(cov_sqrt.mT, cov_sqrt) cov = th.bmm(cov_sqrt.mT, cov_sqrt)
cov += th.eye(cov.shape[-1]).expand(cov.shape)*(1e-6) cov += th.eye(cov.shape[-1]).expand(cov.shape)*(1e-6)
chol = th.linalg.cholesky(cov) chol = th.linalg.cholesky(cov)
if vec: if only_diag:
chol = th.diagonal(chol, dim1=-2, dim2=-1) chol = th.diagonal(chol, dim1=-2, dim2=-1)
return chol return chol