Checked between changes from Fabian's Public / Private version

This commit is contained in:
Dominik Moritz Roth 2022-10-27 20:32:38 +02:00
parent fc542f21f8
commit 64626322bd
2 changed files with 3 additions and 3 deletions

View File

@ -166,7 +166,7 @@ def entropy_equality_projection(p: th.distributions.Normal,
def mean_projection(mean: th.Tensor, old_mean: th.Tensor, maha: th.Tensor, eps: th.Tensor): def mean_projection(mean: th.Tensor, old_mean: th.Tensor, maha: th.Tensor, eps: th.Tensor):
""" """
Stolen from Fabian's Code (Private Version) Stolen from Fabian's Code (Public Version)
Projects the mean based on the Mahalanobis objective and trust region. Projects the mean based on the Mahalanobis objective and trust region.
Args: Args:

View File

@ -4,7 +4,7 @@ from typing import Tuple, Any
from ..misc.norm import mahalanobis from ..misc.norm import mahalanobis
from .base_projection_layer import BaseProjectionLayer, mean_projection, mean_equality_projection from .base_projection_layer import BaseProjectionLayer, mean_projection
from ..misc.norm import mahalanobis, _batch_trace from ..misc.norm import mahalanobis, _batch_trace
from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, new_dist_like_from_sqrt, has_diag_cov from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, new_dist_like_from_sqrt, has_diag_cov
@ -12,7 +12,7 @@ from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_
class WassersteinProjectionLayer(BaseProjectionLayer): class WassersteinProjectionLayer(BaseProjectionLayer):
""" """
Stolen from Fabian's Code (Private Version) Stolen from Fabian's Code (Public Version)
""" """
def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs): def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):