Fix issues with full cov

This commit is contained in:
Dominik Moritz Roth 2024-04-01 00:03:10 +02:00
parent f129928635
commit 6441bbfc5b

View File

@ -27,23 +27,8 @@ def get_mean_and_chol(p: AnyDistribution, expand=False):
def get_mean_and_sqrt(p: UniversalGaussianDistribution, expand=False): def get_mean_and_sqrt(p: UniversalGaussianDistribution, expand=False):
mean, chol = get_mean_and_chol(p, expand=False) mean, chol = get_mean_and_chol(p, expand=expand)
if not hasattr(p, 'cov_sqrt'): return mean, chol
if len(p.distribution.scale.shape)==2: # In the factorized case, chol = matSqrt
sqrt_cov = p.distribution.scale
else:
raise Exception(
'Distribution was not induced from sqrt. On-demand calculation is not supported.')
else:
sqrt_cov = p.cov_sqrt
if mean.shape[0] != sqrt_cov.shape[0]:
shape = list(sqrt_cov.shape)
shape[0] = mean.shape[0]
shape = tuple(shape)
sqrt_cov = sqrt_cov.expand(shape)
if expand and len(sqrt_cov.shape) <= 2:
sqrt_cov = th.diag_embed(sqrt_cov)
return mean, sqrt_cov
def get_cov(p: AnyDistribution): def get_cov(p: AnyDistribution):
@ -111,29 +96,3 @@ def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
return p_out return p_out
else: else:
raise Exception('Dist-Type not implemented') raise Exception('Dist-Type not implemented')
def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt: th.Tensor):
chol = _sqrt_to_chol(cov_sqrt, only_diag=has_diag_cov(orig_p))
new = new_dist_like(orig_p, mean, chol)
new.cov_sqrt = cov_sqrt
if hasattr(new, 'distribution'):
new.distribution.cov_sqrt = cov_sqrt
return new
def _sqrt_to_chol(cov_sqrt, only_diag=False):
if only_diag:
return cov_sqrt
cov = th.bmm(cov_sqrt.mT, cov_sqrt)
cov += th.eye(cov.shape[-1]).expand(cov.shape)*(1e-6)
chol = th.linalg.cholesky(cov)
#if only_diag:
# chol = th.diagonal(chol, dim1=-2, dim2=-1)
return chol