Checked between changes from Fabian's Public / Private version
This commit is contained in:
parent
fc542f21f8
commit
64626322bd
@ -166,7 +166,7 @@ def entropy_equality_projection(p: th.distributions.Normal,
|
||||
|
||||
def mean_projection(mean: th.Tensor, old_mean: th.Tensor, maha: th.Tensor, eps: th.Tensor):
|
||||
"""
|
||||
Stolen from Fabian's Code (Private Version)
|
||||
Stolen from Fabian's Code (Public Version)
|
||||
|
||||
Projects the mean based on the Mahalanobis objective and trust region.
|
||||
Args:
|
||||
|
@ -4,7 +4,7 @@ from typing import Tuple, Any
|
||||
|
||||
from ..misc.norm import mahalanobis
|
||||
|
||||
from .base_projection_layer import BaseProjectionLayer, mean_projection, mean_equality_projection
|
||||
from .base_projection_layer import BaseProjectionLayer, mean_projection
|
||||
|
||||
from ..misc.norm import mahalanobis, _batch_trace
|
||||
from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, new_dist_like_from_sqrt, has_diag_cov
|
||||
@ -12,7 +12,7 @@ from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_
|
||||
|
||||
class WassersteinProjectionLayer(BaseProjectionLayer):
|
||||
"""
|
||||
Stolen from Fabian's Code (Private Version)
|
||||
Stolen from Fabian's Code (Public Version)
|
||||
"""
|
||||
|
||||
def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):
|
||||
|
Loading…
Reference in New Issue
Block a user