Refactor w2 proj
This commit is contained in:
parent
6441bbfc5b
commit
f1c047d387
@ -7,7 +7,7 @@ from ..misc.norm import mahalanobis
|
||||
from .base_projection_layer import BaseProjectionLayer, mean_projection
|
||||
|
||||
from ..misc.norm import mahalanobis, _batch_trace
|
||||
from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, new_dist_like_from_sqrt, has_diag_cov
|
||||
from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, has_diag_cov
|
||||
|
||||
from stable_baselines3.common.distributions import Distribution
|
||||
|
||||
@ -31,6 +31,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer):
|
||||
Returns:
|
||||
mean, cov sqrt
|
||||
"""
|
||||
|
||||
mean, sqrt = get_mean_and_sqrt(p, expand=True)
|
||||
old_mean, old_sqrt = get_mean_and_sqrt(q, expand=True)
|
||||
batch_shape = mean.shape[:-1]
|
||||
@ -66,7 +67,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer):
|
||||
if has_diag_cov(p):
|
||||
proj_sqrt = th.diagonal(proj_sqrt, dim1=-2, dim2=-1)
|
||||
|
||||
proj_p = new_dist_like_from_sqrt(p, proj_mean, proj_sqrt)
|
||||
proj_p = self.new_dist_like(p, proj_mean, proj_sqrt)
|
||||
return proj_p
|
||||
|
||||
def trust_region_value(self, p, q):
|
||||
@ -90,7 +91,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer):
|
||||
# projected distribution
|
||||
|
||||
proj_mean, proj_sqrt = get_mean_and_sqrt(proj_p)
|
||||
p_target = new_dist_like_from_sqrt(p, proj_mean, proj_sqrt)
|
||||
p_target = self.new_dist_like(p, proj_mean, proj_sqrt)
|
||||
kl_diff = self.trust_region_value(p, p_target)
|
||||
|
||||
kl_loss = kl_diff.mean()
|
||||
@ -110,7 +111,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer):
|
||||
elif isinstance(p, th.distributions.MultivariateNormal):
|
||||
p_out = orig_p.__class__(orig_p.action_dim)
|
||||
p_out.distribution = th.distributions.MultivariateNormal(
|
||||
mean, scale_tril=cov_sqrt)
|
||||
mean, scale_tril=cov_sqrt, validate_args=False)
|
||||
else:
|
||||
raise Exception('Dist-Type not implemented (of sb3 dist)')
|
||||
p_out.cov_sqrt = cov_sqrt
|
||||
@ -157,11 +158,7 @@ def gaussian_wasserstein_commutative(p, q, scale_prec=False) -> Tuple[th.Tensor,
|
||||
else:
|
||||
# W2 objective for cov assuming normal W2 objective for mean
|
||||
cov_other = get_cov(q)
|
||||
try:
|
||||
cov_part = _batch_trace(
|
||||
cov_other + cov - 2 * th.bmm(sqrt_other, sqrt))
|
||||
except:
|
||||
import pdb
|
||||
pdb.set_trace()
|
||||
|
||||
return mean_part, cov_part
|
||||
|
Loading…
Reference in New Issue
Block a user