Checked between changes from Fabian's Public / Private version
This commit is contained in:
parent
fc542f21f8
commit
64626322bd
@ -166,7 +166,7 @@ def entropy_equality_projection(p: th.distributions.Normal,
|
|||||||
|
|
||||||
def mean_projection(mean: th.Tensor, old_mean: th.Tensor, maha: th.Tensor, eps: th.Tensor):
|
def mean_projection(mean: th.Tensor, old_mean: th.Tensor, maha: th.Tensor, eps: th.Tensor):
|
||||||
"""
|
"""
|
||||||
Stolen from Fabian's Code (Private Version)
|
Stolen from Fabian's Code (Public Version)
|
||||||
|
|
||||||
Projects the mean based on the Mahalanobis objective and trust region.
|
Projects the mean based on the Mahalanobis objective and trust region.
|
||||||
Args:
|
Args:
|
||||||
|
@ -4,7 +4,7 @@ from typing import Tuple, Any
|
|||||||
|
|
||||||
from ..misc.norm import mahalanobis
|
from ..misc.norm import mahalanobis
|
||||||
|
|
||||||
from .base_projection_layer import BaseProjectionLayer, mean_projection, mean_equality_projection
|
from .base_projection_layer import BaseProjectionLayer, mean_projection
|
||||||
|
|
||||||
from ..misc.norm import mahalanobis, _batch_trace
|
from ..misc.norm import mahalanobis, _batch_trace
|
||||||
from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, new_dist_like_from_sqrt, has_diag_cov
|
from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, new_dist_like_from_sqrt, has_diag_cov
|
||||||
@ -12,7 +12,7 @@ from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_
|
|||||||
|
|
||||||
class WassersteinProjectionLayer(BaseProjectionLayer):
|
class WassersteinProjectionLayer(BaseProjectionLayer):
|
||||||
"""
|
"""
|
||||||
Stolen from Fabian's Code (Private Version)
|
Stolen from Fabian's Code (Public Version)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):
|
def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):
|
||||||
|
Loading…
Reference in New Issue
Block a user