import torch as th
from torch.distributions.multivariate_normal import _batch_mahalanobis


def mahalanobis_alt(u, v, std):
    """Return the squared Mahalanobis-style distance between ``u`` and ``v``.

    Adapted from Fabian's code (public version).

    Solves the lower-triangular system ``std @ x = u - v`` and returns the
    sum of ``x ** 2`` over the last two dimensions.

    Args:
        u: First tensor; must be broadcast-compatible with ``v``.
        v: Second tensor.
        std: Coefficient matrix of the triangular system; only its lower
            triangle is used (``upper=False``). Presumably a Cholesky factor
            of the covariance -- confirm against callers.

    Returns:
        Tensor with the last two dimensions of the solution reduced away.
    """
    delta = u - v
    # ``th.triangular_solve`` is deprecated; prefer the ``linalg`` routine when
    # this PyTorch build provides it and fall back to the legacy call otherwise.
    linalg = getattr(th, "linalg", None)
    if hasattr(linalg, "solve_triangular"):
        whitened = linalg.solve_triangular(std, delta, upper=False)
    else:
        whitened = th.triangular_solve(delta, std, upper=False)[0]
    return whitened.pow(2).sum([-2, -1])


def mahalanobis(u, v, chol):
    """Return the squared Mahalanobis distance of ``u - v``.

    Thin wrapper that forwards ``chol`` and the difference ``u - v`` to
    PyTorch's private ``_batch_mahalanobis`` helper.

    Args:
        u: First tensor.
        v: Second tensor; must be broadcast-compatible with ``u``.
        chol: Passed as the first argument of ``_batch_mahalanobis``
            (presumably the lower Cholesky factor of the covariance --
            confirm against callers).
    """
    delta = u - v
    return _batch_mahalanobis(chol, delta)


def frob_sq(diff, is_spd=False):
    """Return the squared Frobenius norm of ``diff``.

    Args:
        diff: Input tensor.
        is_spd: When true, ``diff`` is treated as a (batch of) symmetric
            square matrices and the trace identity in ``_frob_sq_spd`` is
            used instead of ``th.norm``.

    Returns:
        With ``is_spd=False`` the norm is reduced over every dimension except
        the first; with ``is_spd=True`` it is reduced over the last two
        dimensions only.

    NOTE(review): the two branches therefore agree only for inputs shaped
    ``(batch, n, n)``; confirm callers never rely on other layouts.
    """
    if is_spd:
        return _frob_sq_spd(diff)
    return th.norm(diff, p='fro', dim=tuple(range(1, diff.dim()))).pow(2)


def _frob_sq_spd(diff):
    """Return ``trace(diff @ diff)`` for a batch of square matrices.

    For symmetric ``diff`` this equals the squared Frobenius norm.
    """
    # trace(A @ A) == sum_ij A_ij * A_ji, so an element-wise product with the
    # transpose gives the same value in O(n^2) without materialising the
    # O(n^3) matrix product.
    return (diff * diff.transpose(-2, -1)).sum(dim=(-2, -1))


def _batch_trace(x):
    """Return the trace over the last two dimensions of ``x``."""
    return th.diagonal(x, dim1=-2, dim2=-1).sum(-1)