From 6441bbfc5bb16b3b60b7993f82702733f5c3ecc2 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 1 Apr 2024 00:03:10 +0200 Subject: [PATCH] Fix issues with full cov --- metastable_projections/misc/distTools.py | 47 ++---------------------- 1 file changed, 3 insertions(+), 44 deletions(-) diff --git a/metastable_projections/misc/distTools.py b/metastable_projections/misc/distTools.py index 0c706e3..d8613af 100644 --- a/metastable_projections/misc/distTools.py +++ b/metastable_projections/misc/distTools.py @@ -27,23 +27,8 @@ def get_mean_and_chol(p: AnyDistribution, expand=False): def get_mean_and_sqrt(p: UniversalGaussianDistribution, expand=False): - mean, chol = get_mean_and_chol(p, expand=False) - if not hasattr(p, 'cov_sqrt'): - if len(p.distribution.scale.shape)==2: # In the factorized case, chol = matSqrt - sqrt_cov = p.distribution.scale - else: - raise Exception( - 'Distribution was not induced from sqrt. On-demand calculation is not supported.') - else: - sqrt_cov = p.cov_sqrt - if mean.shape[0] != sqrt_cov.shape[0]: - shape = list(sqrt_cov.shape) - shape[0] = mean.shape[0] - shape = tuple(shape) - sqrt_cov = sqrt_cov.expand(shape) - if expand and len(sqrt_cov.shape) <= 2: - sqrt_cov = th.diag_embed(sqrt_cov) - return mean, sqrt_cov + mean, chol = get_mean_and_chol(p, expand=expand) + return mean, chol def get_cov(p: AnyDistribution): @@ -110,30 +95,4 @@ def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor): raise Exception('Dist-Type not implemented (of sb3 dist)') return p_out else: - raise Exception('Dist-Type not implemented') - - -def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt: th.Tensor): - chol = _sqrt_to_chol(cov_sqrt, only_diag=has_diag_cov(orig_p)) - - new = new_dist_like(orig_p, mean, chol) - - new.cov_sqrt = cov_sqrt - if hasattr(new, 'distribution'): - new.distribution.cov_sqrt = cov_sqrt - - return new - - -def _sqrt_to_chol(cov_sqrt, only_diag=False): - if only_diag: - return cov_sqrt - cov = th.bmm(cov_sqrt.mT, cov_sqrt) - cov += th.eye(cov.shape[-1]).expand(cov.shape)*(1e-6) - - chol = th.linalg.cholesky(cov) - - #if only_diag: - # chol = th.diagonal(chol, dim1=-2, dim2=-1) - - return chol + raise Exception('Dist-Type not implemented') \ No newline at end of file