Factored out frob_sq and perf improvement for spd input

This commit is contained in:
Dominik Moritz Roth 2022-06-27 13:44:08 +02:00
parent f4c87c9cdc
commit 416dde202d

View File

@ -10,3 +10,18 @@ def mahalanobis_blub(u, v, std):
def mahalanobis(u, v, cov):
delta = u - v
return _batch_mahalanobis(cov, delta)
def frob_sq(diff, is_spd=False):
# If diff is spd, we can use a more perfromant algorithm
if is_spd:
return _frob_sq_spd(diff)
return th.norm(diff, p='fro', dim=tuple(range(1, diff.dim()))).pow(2)
def _frob_sq_spd(diff):
return _batch_trace(diff @ diff)
def _batch_trace(x):
return th.diagonal(x, dim1=-2, dim2=-1).sum(-1)