Factored out frob_sq and perf improvement for spd input
This commit is contained in:
parent
f4c87c9cdc
commit
416dde202d
@ -10,3 +10,18 @@ def mahalanobis_blub(u, v, std):
|
||||
def mahalanobis(u, v, cov):
|
||||
delta = u - v
|
||||
return _batch_mahalanobis(cov, delta)
|
||||
|
||||
|
||||
def frob_sq(diff, is_spd=False):
|
||||
# If diff is spd, we can use a more perfromant algorithm
|
||||
if is_spd:
|
||||
return _frob_sq_spd(diff)
|
||||
return th.norm(diff, p='fro', dim=tuple(range(1, diff.dim()))).pow(2)
|
||||
|
||||
|
||||
def _frob_sq_spd(diff):
|
||||
return _batch_trace(diff @ diff)
|
||||
|
||||
|
||||
def _batch_trace(x):
|
||||
return th.diagonal(x, dim1=-2, dim2=-1).sum(-1)
|
||||
|
Loading…
Reference in New Issue
Block a user