Refactor w2 proj

This commit is contained in:
Dominik Moritz Roth 2024-04-01 00:03:27 +02:00
parent 6441bbfc5b
commit f1c047d387

View File

@ -7,7 +7,7 @@ from ..misc.norm import mahalanobis
from .base_projection_layer import BaseProjectionLayer, mean_projection from .base_projection_layer import BaseProjectionLayer, mean_projection
from ..misc.norm import mahalanobis, _batch_trace from ..misc.norm import mahalanobis, _batch_trace
from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, new_dist_like_from_sqrt, has_diag_cov from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, has_diag_cov
from stable_baselines3.common.distributions import Distribution from stable_baselines3.common.distributions import Distribution
@ -31,6 +31,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer):
Returns: Returns:
mean, cov sqrt mean, cov sqrt
""" """
mean, sqrt = get_mean_and_sqrt(p, expand=True) mean, sqrt = get_mean_and_sqrt(p, expand=True)
old_mean, old_sqrt = get_mean_and_sqrt(q, expand=True) old_mean, old_sqrt = get_mean_and_sqrt(q, expand=True)
batch_shape = mean.shape[:-1] batch_shape = mean.shape[:-1]
@ -66,7 +67,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer):
if has_diag_cov(p): if has_diag_cov(p):
proj_sqrt = th.diagonal(proj_sqrt, dim1=-2, dim2=-1) proj_sqrt = th.diagonal(proj_sqrt, dim1=-2, dim2=-1)
proj_p = new_dist_like_from_sqrt(p, proj_mean, proj_sqrt) proj_p = self.new_dist_like(p, proj_mean, proj_sqrt)
return proj_p return proj_p
def trust_region_value(self, p, q): def trust_region_value(self, p, q):
@ -90,7 +91,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer):
# projected distribution # projected distribution
proj_mean, proj_sqrt = get_mean_and_sqrt(proj_p) proj_mean, proj_sqrt = get_mean_and_sqrt(proj_p)
p_target = new_dist_like_from_sqrt(p, proj_mean, proj_sqrt) p_target = self.new_dist_like(p, proj_mean, proj_sqrt)
kl_diff = self.trust_region_value(p, p_target) kl_diff = self.trust_region_value(p, p_target)
kl_loss = kl_diff.mean() kl_loss = kl_diff.mean()
@ -110,7 +111,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer):
elif isinstance(p, th.distributions.MultivariateNormal): elif isinstance(p, th.distributions.MultivariateNormal):
p_out = orig_p.__class__(orig_p.action_dim) p_out = orig_p.__class__(orig_p.action_dim)
p_out.distribution = th.distributions.MultivariateNormal( p_out.distribution = th.distributions.MultivariateNormal(
mean, scale_tril=cov_sqrt) mean, scale_tril=cov_sqrt, validate_args=False)
else: else:
raise Exception('Dist-Type not implemented (of sb3 dist)') raise Exception('Dist-Type not implemented (of sb3 dist)')
p_out.cov_sqrt = cov_sqrt p_out.cov_sqrt = cov_sqrt
@ -157,11 +158,7 @@ def gaussian_wasserstein_commutative(p, q, scale_prec=False) -> Tuple[th.Tensor,
else: else:
# W2 objective for cov assuming normal W2 objective for mean # W2 objective for cov assuming normal W2 objective for mean
cov_other = get_cov(q) cov_other = get_cov(q)
try: cov_part = _batch_trace(
cov_part = _batch_trace( cov_other + cov - 2 * th.bmm(sqrt_other, sqrt))
cov_other + cov - 2 * th.bmm(sqrt_other, sqrt))
except:
import pdb
pdb.set_trace()
return mean_part, cov_part return mean_part, cov_part