diff --git a/sb3_trl/misc/distTools.py b/sb3_trl/misc/distTools.py index e7ee3e0..6496ea8 100644 --- a/sb3_trl/misc/distTools.py +++ b/sb3_trl/misc/distTools.py @@ -12,7 +12,7 @@ def get_mean_and_chol(p, expand=False): elif isinstance(p, th.distributions.MultivariateNormal): return p.mean, p.scale_tril elif isinstance(p, SB3_Distribution): - return get_mean_and_chol(p.distribution) + return get_mean_and_chol(p.distribution, expand=expand) else: raise Exception('Dist-Type not implemented') @@ -28,6 +28,24 @@ def get_cov(p): raise Exception('Dist-Type not implemented') +def has_diag_cov(p, numerical_check=True): + if isinstance(p, SB3_Distribution): + return has_diag_cov(p.distribution, numerical_check=numerical_check) + if isinstance(p, th.distributions.Normal): + return True + if not numerical_check: + return False + # Check if matrix is diag + cov = get_cov(p) + return th.equal(cov - th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1), th.zeros_like(cov))) + + +def get_diag_cov_vec(p, check_diag=True, numerical_check=True): + if check_diag and not has_diag_cov(p): + raise Exception('Cannot reduce cov-mat to diag-vec: Is not diagonal') + return th.diagonal(get_cov(p), dim1=-2, dim2=-1) + + def new_dist_like(orig_p, mean, chol): if isinstance(orig_p, th.distributions.Normal): if orig_p.stddev.shape != chol.shape: diff --git a/sb3_trl/misc/norm.py b/sb3_trl/misc/norm.py index 5715334..74d1f66 100644 --- a/sb3_trl/misc/norm.py +++ b/sb3_trl/misc/norm.py @@ -2,7 +2,7 @@ import torch as th from torch.distributions.multivariate_normal import _batch_mahalanobis -def mahalanobis_blub(u, v, std): +def mahalanobis_alt(u, v, std): delta = u - v return th.triangular_solve(delta, std, upper=False)[0].pow(2).sum([-2, -1]) @@ -13,7 +13,7 @@ def mahalanobis(u, v, cov): def frob_sq(diff, is_spd=False): - # If diff is spd, we can use a more perfromant algorithm + # If diff is spd, we can use a (probably) more performant algorithm if is_spd: return _frob_sq_spd(diff) return th.norm(diff, p='fro', dim=tuple(range(1, diff.dim()))).pow(2)