From f1c047d38766297fc16eb4a235909ae9e6a4dc3d Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 1 Apr 2024 00:03:27 +0200 Subject: [PATCH] Refactor w2 proj --- .../projections/w2_projection_layer.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/metastable_projections/projections/w2_projection_layer.py b/metastable_projections/projections/w2_projection_layer.py index 49e714e..7296f28 100644 --- a/metastable_projections/projections/w2_projection_layer.py +++ b/metastable_projections/projections/w2_projection_layer.py @@ -7,7 +7,7 @@ from ..misc.norm import mahalanobis from .base_projection_layer import BaseProjectionLayer, mean_projection from ..misc.norm import mahalanobis, _batch_trace -from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, new_dist_like_from_sqrt, has_diag_cov +from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, has_diag_cov from stable_baselines3.common.distributions import Distribution @@ -31,6 +31,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer): Returns: mean, cov sqrt """ + mean, sqrt = get_mean_and_sqrt(p, expand=True) old_mean, old_sqrt = get_mean_and_sqrt(q, expand=True) batch_shape = mean.shape[:-1] @@ -66,7 +67,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer): if has_diag_cov(p): proj_sqrt = th.diagonal(proj_sqrt, dim1=-2, dim2=-1) - proj_p = new_dist_like_from_sqrt(p, proj_mean, proj_sqrt) + proj_p = self.new_dist_like(p, proj_mean, proj_sqrt) return proj_p def trust_region_value(self, p, q): @@ -90,7 +91,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer): # projected distribution proj_mean, proj_sqrt = get_mean_and_sqrt(proj_p) - p_target = new_dist_like_from_sqrt(p, proj_mean, proj_sqrt) + p_target = self.new_dist_like(p, proj_mean, proj_sqrt) kl_diff = self.trust_region_value(p, p_target) kl_loss = kl_diff.mean() @@ -110,7 +111,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer): elif isinstance(p, th.distributions.MultivariateNormal): p_out = orig_p.__class__(orig_p.action_dim) p_out.distribution = th.distributions.MultivariateNormal( - mean, scale_tril=cov_sqrt) + mean, scale_tril=cov_sqrt, validate_args=False) else: raise Exception('Dist-Type not implemented (of sb3 dist)') p_out.cov_sqrt = cov_sqrt @@ -157,11 +158,7 @@ def gaussian_wasserstein_commutative(p, q, scale_prec=False) -> Tuple[th.Tensor, else: # W2 objective for cov assuming normal W2 objective for mean cov_other = get_cov(q) - try: - cov_part = _batch_trace( - cov_other + cov - 2 * th.bmm(sqrt_other, sqrt)) - except: - import pdb - pdb.set_trace() + cov_part = _batch_trace( + cov_other + cov - 2 * th.bmm(sqrt_other, sqrt)) return mean_part, cov_part