From 64626322bd17fc960ab0b60b17d43e35e709e810 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Thu, 27 Oct 2022 20:32:38 +0200 Subject: [PATCH] Checked between changes from Fabian's Public / Private version --- metastable_projections/projections/base_projection_layer.py | 2 +- metastable_projections/projections/w2_projection_layer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/metastable_projections/projections/base_projection_layer.py b/metastable_projections/projections/base_projection_layer.py index 6809383..31950d7 100644 --- a/metastable_projections/projections/base_projection_layer.py +++ b/metastable_projections/projections/base_projection_layer.py @@ -166,7 +166,7 @@ def entropy_equality_projection(p: th.distributions.Normal, def mean_projection(mean: th.Tensor, old_mean: th.Tensor, maha: th.Tensor, eps: th.Tensor): """ - Stolen from Fabian's Code (Private Version) + Stolen from Fabian's Code (Public Version) Projects the mean based on the Mahalanobis objective and trust region. Args: diff --git a/metastable_projections/projections/w2_projection_layer.py b/metastable_projections/projections/w2_projection_layer.py index 2bb6321..b2065ab 100644 --- a/metastable_projections/projections/w2_projection_layer.py +++ b/metastable_projections/projections/w2_projection_layer.py @@ -4,7 +4,7 @@ from typing import Tuple, Any from ..misc.norm import mahalanobis -from .base_projection_layer import BaseProjectionLayer, mean_projection, mean_equality_projection +from .base_projection_layer import BaseProjectionLayer, mean_projection from ..misc.norm import mahalanobis, _batch_trace from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, new_dist_like_from_sqrt, has_diag_cov @@ -12,7 +12,7 @@ from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_ class WassersteinProjectionLayer(BaseProjectionLayer): """ - Stolen from Fabian's Code (Private Version) + Stolen from Fabian's Code (Public Version) """ def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):