Checked between changes from Fabian's Public / Private version

This commit is contained in:
Dominik Moritz Roth 2022-10-27 20:32:38 +02:00
parent fc542f21f8
commit 64626322bd
2 changed files with 3 additions and 3 deletions

View File

@ -166,7 +166,7 @@ def entropy_equality_projection(p: th.distributions.Normal,
def mean_projection(mean: th.Tensor, old_mean: th.Tensor, maha: th.Tensor, eps: th.Tensor):
"""
Stolen from Fabian's Code (Private Version)
Stolen from Fabian's Code (Public Version)
Projects the mean based on the Mahalanobis objective and trust region.
Args:

View File

@ -4,7 +4,7 @@ from typing import Tuple, Any
from ..misc.norm import mahalanobis
from .base_projection_layer import BaseProjectionLayer, mean_projection, mean_equality_projection
from .base_projection_layer import BaseProjectionLayer, mean_projection
from ..misc.norm import mahalanobis, _batch_trace
from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, new_dist_like_from_sqrt, has_diag_cov
@ -12,7 +12,7 @@ from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_
class WassersteinProjectionLayer(BaseProjectionLayer):
"""
Stolen from Fabian's Code (Private Version)
Stolen from Fabian's Code (Public Version)
"""
def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):