Fixed bug with Wasserstein
This commit is contained in:
parent
c7ca326345
commit
05b048fd0e
@ -111,7 +111,7 @@ def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
|
|||||||
|
|
||||||
|
|
||||||
def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt: th.Tensor):
|
def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt: th.Tensor):
|
||||||
chol = _sqrt_to_chol(cov_sqrt)
|
chol = _sqrt_to_chol(cov_sqrt, only_diag=has_diag_cov(orig_p))
|
||||||
|
|
||||||
new = new_dist_like(orig_p, mean, chol)
|
new = new_dist_like(orig_p, mean, chol)
|
||||||
|
|
||||||
@ -122,20 +122,13 @@ def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt:
|
|||||||
return new
|
return new
|
||||||
|
|
||||||
|
|
||||||
def _sqrt_to_chol(cov_sqrt):
|
def _sqrt_to_chol(cov_sqrt, only_diag=False):
|
||||||
vec = False
|
|
||||||
if len(cov_sqrt.shape) == 2:
|
|
||||||
vec = True
|
|
||||||
|
|
||||||
if vec:
|
|
||||||
cov_sqrt = th.diag_embed(cov_sqrt)
|
|
||||||
|
|
||||||
cov = th.bmm(cov_sqrt.mT, cov_sqrt)
|
cov = th.bmm(cov_sqrt.mT, cov_sqrt)
|
||||||
cov += th.eye(cov.shape[-1]).expand(cov.shape)*(1e-6)
|
cov += th.eye(cov.shape[-1]).expand(cov.shape)*(1e-6)
|
||||||
|
|
||||||
chol = th.linalg.cholesky(cov)
|
chol = th.linalg.cholesky(cov)
|
||||||
|
|
||||||
if vec:
|
if only_diag:
|
||||||
chol = th.diagonal(chol, dim1=-2, dim2=-1)
|
chol = th.diagonal(chol, dim1=-2, dim2=-1)
|
||||||
|
|
||||||
return chol
|
return chol
|
||||||
|
Loading…
Reference in New Issue
Block a user