Fix issues when providing diagonal matrix sqrt
This commit is contained in:
parent
7ad4858e8a
commit
f129928635
@@ -27,20 +27,23 @@ def get_mean_and_chol(p: AnyDistribution, expand=False):
|
|||||||
|
|
||||||
|
|
||||||
def get_mean_and_sqrt(p: UniversalGaussianDistribution, expand=False):
|
def get_mean_and_sqrt(p: UniversalGaussianDistribution, expand=False):
|
||||||
|
mean, chol = get_mean_and_chol(p, expand=False)
|
||||||
if not hasattr(p, 'cov_sqrt'):
|
if not hasattr(p, 'cov_sqrt'):
|
||||||
raise Exception(
|
if len(p.distribution.scale.shape)==2: # In the factorized case, chol = matSqrt
|
||||||
'Distribution was not induced from sqrt. On-demand calculation is not supported.')
|
sqrt_cov = p.distribution.scale
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
'Distribution was not induced from sqrt. On-demand calculation is not supported.')
|
||||||
else:
|
else:
|
||||||
mean, chol = get_mean_and_chol(p, expand=False)
|
|
||||||
sqrt_cov = p.cov_sqrt
|
sqrt_cov = p.cov_sqrt
|
||||||
if mean.shape[0] != sqrt_cov.shape[0]:
|
if mean.shape[0] != sqrt_cov.shape[0]:
|
||||||
shape = list(sqrt_cov.shape)
|
shape = list(sqrt_cov.shape)
|
||||||
shape[0] = mean.shape[0]
|
shape[0] = mean.shape[0]
|
||||||
shape = tuple(shape)
|
shape = tuple(shape)
|
||||||
sqrt_cov = sqrt_cov.expand(shape)
|
sqrt_cov = sqrt_cov.expand(shape)
|
||||||
if expand and len(sqrt_cov.shape) <= 2:
|
if expand and len(sqrt_cov.shape) <= 2:
|
||||||
sqrt_cov = th.diag_embed(sqrt_cov)
|
sqrt_cov = th.diag_embed(sqrt_cov)
|
||||||
return mean, sqrt_cov
|
return mean, sqrt_cov
|
||||||
|
|
||||||
|
|
||||||
def get_cov(p: AnyDistribution):
|
def get_cov(p: AnyDistribution):
|
||||||
@@ -123,12 +126,14 @@ def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt:
|
|||||||
|
|
||||||
|
|
||||||
def _sqrt_to_chol(cov_sqrt, only_diag=False):
|
def _sqrt_to_chol(cov_sqrt, only_diag=False):
|
||||||
|
if only_diag:
|
||||||
|
return cov_sqrt
|
||||||
cov = th.bmm(cov_sqrt.mT, cov_sqrt)
|
cov = th.bmm(cov_sqrt.mT, cov_sqrt)
|
||||||
cov += th.eye(cov.shape[-1]).expand(cov.shape)*(1e-6)
|
cov += th.eye(cov.shape[-1]).expand(cov.shape)*(1e-6)
|
||||||
|
|
||||||
chol = th.linalg.cholesky(cov)
|
chol = th.linalg.cholesky(cov)
|
||||||
|
|
||||||
if only_diag:
|
#if only_diag:
|
||||||
chol = th.diagonal(chol, dim1=-2, dim2=-1)
|
# chol = th.diagonal(chol, dim1=-2, dim2=-1)
|
||||||
|
|
||||||
return chol
|
return chol
|
||||||
|
@@ -86,6 +86,7 @@ class BaseProjectionLayer(object):
|
|||||||
Returns:
|
Returns:
|
||||||
projected_dist, old_dist (from rollouts)
|
projected_dist, old_dist (from rollouts)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
old_distribution = self.new_dist_like(dist, rollout_data.mean, rollout_data.cov_decomp)
|
old_distribution = self.new_dist_like(dist, rollout_data.mean, rollout_data.cov_decomp)
|
||||||
return self(dist, old_distribution, **kwargs), old_distribution
|
return self(dist, old_distribution, **kwargs), old_distribution
|
||||||
|
|
||||||
|
@@ -4,7 +4,7 @@ from typing import Tuple
|
|||||||
from .base_projection_layer import BaseProjectionLayer, mean_projection
|
from .base_projection_layer import BaseProjectionLayer, mean_projection
|
||||||
|
|
||||||
from ..misc.norm import mahalanobis, frob_sq
|
from ..misc.norm import mahalanobis, frob_sq
|
||||||
from ..misc.distTools import get_mean_and_chol, get_cov, new_dist_like
|
from ..misc.distTools import get_mean_and_chol, get_cov, new_dist_like, has_diag_cov
|
||||||
|
|
||||||
|
|
||||||
class FrobeniusProjectionLayer(BaseProjectionLayer):
|
class FrobeniusProjectionLayer(BaseProjectionLayer):
|
||||||
@@ -57,6 +57,9 @@ class FrobeniusProjectionLayer(BaseProjectionLayer):
|
|||||||
else:
|
else:
|
||||||
proj_chol = chol
|
proj_chol = chol
|
||||||
|
|
||||||
|
if has_diag_cov(p):
|
||||||
|
proj_chol = th.diagonal(proj_chol, dim1=-2, dim2=-1)
|
||||||
|
|
||||||
proj_p = new_dist_like(p, proj_mean, proj_chol)
|
proj_p = new_dist_like(p, proj_mean, proj_chol)
|
||||||
return proj_p
|
return proj_p
|
||||||
|
|
||||||
|
@@ -63,6 +63,9 @@ class WassersteinProjectionLayer(BaseProjectionLayer):
|
|||||||
else:
|
else:
|
||||||
proj_sqrt = sqrt
|
proj_sqrt = sqrt
|
||||||
|
|
||||||
|
if has_diag_cov(p):
|
||||||
|
proj_sqrt = th.diagonal(proj_sqrt, dim1=-2, dim2=-1)
|
||||||
|
|
||||||
proj_p = new_dist_like_from_sqrt(p, proj_mean, proj_sqrt)
|
proj_p = new_dist_like_from_sqrt(p, proj_mean, proj_sqrt)
|
||||||
return proj_p
|
return proj_p
|
||||||
|
|
||||||
@@ -110,6 +113,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer):
|
|||||||
mean, scale_tril=cov_sqrt)
|
mean, scale_tril=cov_sqrt)
|
||||||
else:
|
else:
|
||||||
raise Exception('Dist-Type not implemented (of sb3 dist)')
|
raise Exception('Dist-Type not implemented (of sb3 dist)')
|
||||||
|
p_out.cov_sqrt = cov_sqrt
|
||||||
return p_out
|
return p_out
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user