Fixed bug with Wasserstein
This commit is contained in:
parent
c7ca326345
commit
05b048fd0e
@ -111,7 +111,7 @@ def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
|
||||
|
||||
|
||||
def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt: th.Tensor):
|
||||
chol = _sqrt_to_chol(cov_sqrt)
|
||||
chol = _sqrt_to_chol(cov_sqrt, only_diag=has_diag_cov(orig_p))
|
||||
|
||||
new = new_dist_like(orig_p, mean, chol)
|
||||
|
||||
@ -122,20 +122,13 @@ def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt:
|
||||
return new
|
||||
|
||||
|
||||
def _sqrt_to_chol(cov_sqrt):
|
||||
vec = False
|
||||
if len(cov_sqrt.shape) == 2:
|
||||
vec = True
|
||||
|
||||
if vec:
|
||||
cov_sqrt = th.diag_embed(cov_sqrt)
|
||||
|
||||
def _sqrt_to_chol(cov_sqrt, only_diag=False):
|
||||
cov = th.bmm(cov_sqrt.mT, cov_sqrt)
|
||||
cov += th.eye(cov.shape[-1]).expand(cov.shape)*(1e-6)
|
||||
|
||||
chol = th.linalg.cholesky(cov)
|
||||
|
||||
if vec:
|
||||
if only_diag:
|
||||
chol = th.diagonal(chol, dim1=-2, dim2=-1)
|
||||
|
||||
return chol
|
||||
|
Loading…
Reference in New Issue
Block a user