13 lines
313 B
Python
13 lines
313 B
Python
|
import torch as th
|
||
|
from torch.distributions.multivariate_normal import _batch_mahalanobis
|
||
|
|
||
|
|
||
|
def mahalanobis_blub(u, v, std):
    """Return the *squared* Mahalanobis distance between ``u`` and ``v``.

    The covariance is supplied in factored form: ``std`` is a lower-triangular
    matrix ``L`` with ``cov = L @ L.T``.  Solving ``L x = (u - v)`` gives
    ``x = L^{-1} (u - v)``, so ``||x||^2 = (u - v)^T cov^{-1} (u - v)``.

    Args:
        u, v: tensors that broadcast to a shape ``(..., n, k)``; solving
            against ``std`` requires the matrix (column) form, typically
            ``k == 1``.
        std: lower-triangular factor of shape ``(..., n, n)``.  Only its lower
            triangle is read.

    Returns:
        Tensor of shape ``(...)``.  It is the sum of squares over the last two
        dimensions, so with ``k > 1`` the columns are summed together.
    """
    delta = u - v
    # torch.triangular_solve is deprecated; linalg.solve_triangular takes the
    # triangular matrix FIRST and the right-hand side second.
    solved = th.linalg.solve_triangular(std, delta, upper=False)
    return solved.pow(2).sum([-2, -1])
|
||
|
|
||
|
|
||
|
def mahalanobis(u, v, cov):
    """Return the *squared* Mahalanobis distance ``(u - v)^T cov^{-1} (u - v)``.

    Args:
        u, v: vectors of shape ``(..., n)``; batch dimensions broadcast
            against ``cov``.
        cov: symmetric positive-definite covariance matrix of shape
            ``(..., n, n)``.

    Returns:
        Tensor with the broadcast batch shape ``(...)``.

    Raises:
        torch.linalg.LinAlgError: if ``cov`` is not positive-definite
            (raised by ``torch.linalg.cholesky``).
    """
    delta = u - v
    # _batch_mahalanobis (a private torch helper) expects the lower Cholesky
    # factor L of the covariance, not the covariance itself.  Passing ``cov``
    # directly would compute delta^T (cov cov^T)^{-1} delta.
    scale_tril = th.linalg.cholesky(cov)
    return _batch_mahalanobis(scale_tril, delta)
|