diff --git a/metastable_projections/misc/distTools.py b/metastable_projections/misc/distTools.py index 67c21b5..5904239 100644 --- a/metastable_projections/misc/distTools.py +++ b/metastable_projections/misc/distTools.py @@ -111,7 +111,7 @@ def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor): def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt: th.Tensor): - chol = _sqrt_to_chol(cov_sqrt) + chol = _sqrt_to_chol(cov_sqrt, only_diag=has_diag_cov(orig_p)) new = new_dist_like(orig_p, mean, chol) @@ -122,20 +122,13 @@ def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt: return new -def _sqrt_to_chol(cov_sqrt): - vec = False - if len(cov_sqrt.shape) == 2: - vec = True - - if vec: - cov_sqrt = th.diag_embed(cov_sqrt) - +def _sqrt_to_chol(cov_sqrt, only_diag=False): cov = th.bmm(cov_sqrt.mT, cov_sqrt) cov += th.eye(cov.shape[-1]).expand(cov.shape)*(1e-6) chol = th.linalg.cholesky(cov) - if vec: + if only_diag: chol = th.diagonal(chol, dim1=-2, dim2=-1) return chol