diff --git a/metastable_projections/misc/distTools.py b/metastable_projections/misc/distTools.py index 5904239..0c706e3 100644 --- a/metastable_projections/misc/distTools.py +++ b/metastable_projections/misc/distTools.py @@ -27,20 +27,23 @@ def get_mean_and_chol(p: AnyDistribution, expand=False): def get_mean_and_sqrt(p: UniversalGaussianDistribution, expand=False): + mean, chol = get_mean_and_chol(p, expand=False) if not hasattr(p, 'cov_sqrt'): - raise Exception( - 'Distribution was not induced from sqrt. On-demand calculation is not supported.') + if len(p.distribution.scale.shape)==2: # In the factorized case, chol = matSqrt + sqrt_cov = p.distribution.scale + else: + raise Exception( + 'Distribution was not induced from sqrt. On-demand calculation is not supported.') else: - mean, chol = get_mean_and_chol(p, expand=False) sqrt_cov = p.cov_sqrt - if mean.shape[0] != sqrt_cov.shape[0]: - shape = list(sqrt_cov.shape) - shape[0] = mean.shape[0] - shape = tuple(shape) - sqrt_cov = sqrt_cov.expand(shape) - if expand and len(sqrt_cov.shape) <= 2: - sqrt_cov = th.diag_embed(sqrt_cov) - return mean, sqrt_cov + if mean.shape[0] != sqrt_cov.shape[0]: + shape = list(sqrt_cov.shape) + shape[0] = mean.shape[0] + shape = tuple(shape) + sqrt_cov = sqrt_cov.expand(shape) + if expand and len(sqrt_cov.shape) <= 2: + sqrt_cov = th.diag_embed(sqrt_cov) + return mean, sqrt_cov def get_cov(p: AnyDistribution): @@ -123,12 +126,14 @@ def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt: def _sqrt_to_chol(cov_sqrt, only_diag=False): + if only_diag: + return cov_sqrt cov = th.bmm(cov_sqrt.mT, cov_sqrt) cov += th.eye(cov.shape[-1]).expand(cov.shape)*(1e-6) chol = th.linalg.cholesky(cov) - if only_diag: - chol = th.diagonal(chol, dim1=-2, dim2=-1) + #if only_diag: + # chol = th.diagonal(chol, dim1=-2, dim2=-1) return chol diff --git a/metastable_projections/projections/base_projection_layer.py b/metastable_projections/projections/base_projection_layer.py index c7b9cb4..7426938 100644 --- a/metastable_projections/projections/base_projection_layer.py +++ b/metastable_projections/projections/base_projection_layer.py @@ -86,6 +86,7 @@ class BaseProjectionLayer(object): Returns: projected_dist, old_dist (from rollouts) """ + old_distribution = self.new_dist_like(dist, rollout_data.mean, rollout_data.cov_decomp) return self(dist, old_distribution, **kwargs), old_distribution diff --git a/metastable_projections/projections/frob_projection_layer.py b/metastable_projections/projections/frob_projection_layer.py index 68db743..2309ab2 100644 --- a/metastable_projections/projections/frob_projection_layer.py +++ b/metastable_projections/projections/frob_projection_layer.py @@ -4,7 +4,7 @@ from typing import Tuple from .base_projection_layer import BaseProjectionLayer, mean_projection from ..misc.norm import mahalanobis, frob_sq -from ..misc.distTools import get_mean_and_chol, get_cov, new_dist_like +from ..misc.distTools import get_mean_and_chol, get_cov, new_dist_like, has_diag_cov class FrobeniusProjectionLayer(BaseProjectionLayer): @@ -57,6 +57,9 @@ class FrobeniusProjectionLayer(BaseProjectionLayer): else: proj_chol = chol + if has_diag_cov(p): + proj_chol = th.diagonal(proj_chol, dim1=-2, dim2=-1) + proj_p = new_dist_like(p, proj_mean, proj_chol) return proj_p diff --git a/metastable_projections/projections/w2_projection_layer.py b/metastable_projections/projections/w2_projection_layer.py index 5bb785c..49e714e 100644 --- a/metastable_projections/projections/w2_projection_layer.py +++ b/metastable_projections/projections/w2_projection_layer.py @@ -63,6 +63,9 @@ class WassersteinProjectionLayer(BaseProjectionLayer): else: proj_sqrt = sqrt + if has_diag_cov(p): + proj_sqrt = th.diagonal(proj_sqrt, dim1=-2, dim2=-1) + proj_p = new_dist_like_from_sqrt(p, proj_mean, proj_sqrt) return proj_p @@ -110,6 +113,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer): mean, scale_tril=cov_sqrt) else: raise Exception('Dist-Type not implemented (of sb3 dist)') + p_out.cov_sqrt = cov_sqrt return p_out