Refactor W2 (Wasserstein) projection: use self.new_dist_like and remove the pdb debugging hook
This commit is contained in:
parent
6441bbfc5b
commit
f1c047d387
@ -7,7 +7,7 @@ from ..misc.norm import mahalanobis
|
|||||||
from .base_projection_layer import BaseProjectionLayer, mean_projection
|
from .base_projection_layer import BaseProjectionLayer, mean_projection
|
||||||
|
|
||||||
from ..misc.norm import mahalanobis, _batch_trace
|
from ..misc.norm import mahalanobis, _batch_trace
|
||||||
from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, new_dist_like_from_sqrt, has_diag_cov
|
from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, has_diag_cov
|
||||||
|
|
||||||
from stable_baselines3.common.distributions import Distribution
|
from stable_baselines3.common.distributions import Distribution
|
||||||
|
|
||||||
@ -31,6 +31,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer):
|
|||||||
Returns:
|
Returns:
|
||||||
mean, cov sqrt
|
mean, cov sqrt
|
||||||
"""
|
"""
|
||||||
|
|
||||||
mean, sqrt = get_mean_and_sqrt(p, expand=True)
|
mean, sqrt = get_mean_and_sqrt(p, expand=True)
|
||||||
old_mean, old_sqrt = get_mean_and_sqrt(q, expand=True)
|
old_mean, old_sqrt = get_mean_and_sqrt(q, expand=True)
|
||||||
batch_shape = mean.shape[:-1]
|
batch_shape = mean.shape[:-1]
|
||||||
@ -66,7 +67,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer):
|
|||||||
if has_diag_cov(p):
|
if has_diag_cov(p):
|
||||||
proj_sqrt = th.diagonal(proj_sqrt, dim1=-2, dim2=-1)
|
proj_sqrt = th.diagonal(proj_sqrt, dim1=-2, dim2=-1)
|
||||||
|
|
||||||
proj_p = new_dist_like_from_sqrt(p, proj_mean, proj_sqrt)
|
proj_p = self.new_dist_like(p, proj_mean, proj_sqrt)
|
||||||
return proj_p
|
return proj_p
|
||||||
|
|
||||||
def trust_region_value(self, p, q):
|
def trust_region_value(self, p, q):
|
||||||
@ -90,7 +91,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer):
|
|||||||
# projected distribution
|
# projected distribution
|
||||||
|
|
||||||
proj_mean, proj_sqrt = get_mean_and_sqrt(proj_p)
|
proj_mean, proj_sqrt = get_mean_and_sqrt(proj_p)
|
||||||
p_target = new_dist_like_from_sqrt(p, proj_mean, proj_sqrt)
|
p_target = self.new_dist_like(p, proj_mean, proj_sqrt)
|
||||||
kl_diff = self.trust_region_value(p, p_target)
|
kl_diff = self.trust_region_value(p, p_target)
|
||||||
|
|
||||||
kl_loss = kl_diff.mean()
|
kl_loss = kl_diff.mean()
|
||||||
@ -110,7 +111,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer):
|
|||||||
elif isinstance(p, th.distributions.MultivariateNormal):
|
elif isinstance(p, th.distributions.MultivariateNormal):
|
||||||
p_out = orig_p.__class__(orig_p.action_dim)
|
p_out = orig_p.__class__(orig_p.action_dim)
|
||||||
p_out.distribution = th.distributions.MultivariateNormal(
|
p_out.distribution = th.distributions.MultivariateNormal(
|
||||||
mean, scale_tril=cov_sqrt)
|
mean, scale_tril=cov_sqrt, validate_args=False)
|
||||||
else:
|
else:
|
||||||
raise Exception('Dist-Type not implemented (of sb3 dist)')
|
raise Exception('Dist-Type not implemented (of sb3 dist)')
|
||||||
p_out.cov_sqrt = cov_sqrt
|
p_out.cov_sqrt = cov_sqrt
|
||||||
@ -157,11 +158,7 @@ def gaussian_wasserstein_commutative(p, q, scale_prec=False) -> Tuple[th.Tensor,
|
|||||||
else:
|
else:
|
||||||
# W2 objective for cov assuming normal W2 objective for mean
|
# W2 objective for cov assuming normal W2 objective for mean
|
||||||
cov_other = get_cov(q)
|
cov_other = get_cov(q)
|
||||||
try:
|
|
||||||
cov_part = _batch_trace(
|
cov_part = _batch_trace(
|
||||||
cov_other + cov - 2 * th.bmm(sqrt_other, sqrt))
|
cov_other + cov - 2 * th.bmm(sqrt_other, sqrt))
|
||||||
except:
|
|
||||||
import pdb
|
|
||||||
pdb.set_trace()
|
|
||||||
|
|
||||||
return mean_part, cov_part
|
return mean_part, cov_part
|
||||||
|
Loading…
Reference in New Issue
Block a user