Fixed chol not expanding bug and function to shrink chol to diag

This commit is contained in:
Dominik Moritz Roth 2022-06-29 12:44:13 +02:00
parent 7c117cfca5
commit 4e77190d8e
2 changed files with 21 additions and 3 deletions

View File

@ -12,7 +12,7 @@ def get_mean_and_chol(p, expand=False):
elif isinstance(p, th.distributions.MultivariateNormal): elif isinstance(p, th.distributions.MultivariateNormal):
return p.mean, p.scale_tril return p.mean, p.scale_tril
elif isinstance(p, SB3_Distribution): elif isinstance(p, SB3_Distribution):
return get_mean_and_chol(p.distribution) return get_mean_and_chol(p.distribution, expand=expand)
else: else:
raise Exception('Dist-Type not implemented') raise Exception('Dist-Type not implemented')
@ -28,6 +28,24 @@ def get_cov(p):
raise Exception('Dist-Type not implemented') raise Exception('Dist-Type not implemented')
def has_diag_cov(p, numerical_check=True):
if isinstance(p, SB3_Distribution):
return has_diag_cov(p.distribution, numerical_check=numerical_check)
if isinstance(p, th.distributions.Normal):
return True
if not numerical_check:
return False
# Check if matrix is diag
cov = get_cov(p)
return th.equal(cov - th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1), th.zeros_like(cov)))
def get_diag_cov_vec(p, check_diag=True, numerical_check=True):
if check_diag and not has_diag_cov(p):
raise Exception('Cannot reduce cov-mat to diag-vec: Is not diagonal')
return th.diagonal(get_cov(p), dim1=-2, dim2=-1)
def new_dist_like(orig_p, mean, chol): def new_dist_like(orig_p, mean, chol):
if isinstance(orig_p, th.distributions.Normal): if isinstance(orig_p, th.distributions.Normal):
if orig_p.stddev.shape != chol.shape: if orig_p.stddev.shape != chol.shape:

View File

@ -2,7 +2,7 @@ import torch as th
from torch.distributions.multivariate_normal import _batch_mahalanobis from torch.distributions.multivariate_normal import _batch_mahalanobis
def mahalanobis_blub(u, v, std): def mahalanobis_alt(u, v, std):
delta = u - v delta = u - v
return th.triangular_solve(delta, std, upper=False)[0].pow(2).sum([-2, -1]) return th.triangular_solve(delta, std, upper=False)[0].pow(2).sum([-2, -1])
@ -13,7 +13,7 @@ def mahalanobis(u, v, cov):
def frob_sq(diff, is_spd=False): def frob_sq(diff, is_spd=False):
# If diff is spd, we can use a more perfromant algorithm # If diff is spd, we can use a (probably) more performant algorithm
if is_spd: if is_spd:
return _frob_sq_spd(diff) return _frob_sq_spd(diff)
return th.norm(diff, p='fro', dim=tuple(range(1, diff.dim()))).pow(2) return th.norm(diff, p='fro', dim=tuple(range(1, diff.dim()))).pow(2)