"""Squared-distance helpers for batched tensors (``fancy_rl/norm.py``)."""
import torch as th
from torch.distributions.multivariate_normal import _batch_mahalanobis


def _solve_lower_triangular(tril, rhs):
    """Solve ``tril @ x = rhs`` for ``x`` using the lower triangle of ``tril``."""
    # ``th.triangular_solve`` is deprecated in favour of
    # ``th.linalg.solve_triangular``; prefer the latter when this PyTorch
    # build provides it.
    solver = getattr(getattr(th, "linalg", None), "solve_triangular", None)
    if solver is not None:
        return solver(tril, rhs, upper=False)
    # Fallback for PyTorch builds without ``linalg.solve_triangular``. Note the
    # swapped argument order and the (solution, cloned_coefficient) tuple.
    return th.triangular_solve(rhs, tril, upper=False)[0]


def mahalanobis_alt(u, v, std):
    """Return the squared norm of ``std^{-1} (u - v)`` via a triangular solve.

    ``std`` is used as a lower-triangular coefficient matrix; the solution is
    squared element-wise and summed over its last two dimensions.

    Args:
        u: Tensor; ``u - v`` is the right-hand side of the solve, so it must
            be matrix-shaped (at least 2-D).
            # NOTE(review): presumably (..., n, k) column vectors - confirm
            # against callers.
        v: Tensor broadcastable against ``u``.
        std: Lower-triangular factor, e.g. a Cholesky factor of a covariance.

    Returns:
        Tensor with the last two dimensions of the solution summed out.
    """
    delta = u - v
    whitened = _solve_lower_triangular(std, delta)
    return whitened.pow(2).sum([-2, -1])


def mahalanobis(u, v, chol):
    """Return the squared Mahalanobis distance between ``u`` and ``v``.

    Delegates to PyTorch's private ``_batch_mahalanobis`` helper, passing
    ``chol`` as the factor and ``u - v`` as the vector argument.

    Args:
        u: Tensor of points.
        v: Tensor broadcastable against ``u``.
        chol: Factor forwarded to ``_batch_mahalanobis``.
            # NOTE(review): expected to be a lower Cholesky factor of the
            # covariance - confirm against callers.

    Returns:
        Whatever ``_batch_mahalanobis(chol, u - v)`` returns.
    """
    delta = u - v
    return _batch_mahalanobis(chol, delta)


def frob_sq(diff, is_spd=False):
    """Return the squared Frobenius norm of ``diff``.

    Args:
        diff: Tensor of matrices.
            # NOTE(review): presumably shaped (batch, n, n) - confirm.
        is_spd: When true, use the trace-based formula ``trace(diff @ diff)``,
            which equals the squared Frobenius norm for symmetric matrices.

    Returns:
        Tensor of squared norms.

    NOTE(review): the default branch reduces over every dimension except the
    first, while the ``is_spd`` branch reduces over the last two only. The two
    agree for 3-D input; for other ranks the output shapes differ.
    NOTE(review): the ``is_spd`` branch performs a matrix product, which does
    more arithmetic than an element-wise square-sum; benchmark before relying
    on it for speed.
    """
    # If diff is spd, we can use a (probably) more performant algorithm
    if is_spd:
        return _frob_sq_spd(diff)
    reduce_dims = tuple(range(1, diff.dim()))
    return th.norm(diff, p='fro', dim=reduce_dims).pow(2)


def _frob_sq_spd(diff):
    """Return ``trace(diff @ diff)`` over the last two dimensions."""
    return _batch_trace(diff @ diff)


def _batch_trace(x):
    """Return the trace over the last two dimensions of ``x``."""
    return th.diagonal(x, dim1=-2, dim2=-1).sum(-1)