metastable-baselines/metastable_baselines/misc/norm.py

32 lines
746 B
Python
Raw Normal View History

2022-06-26 16:39:37 +02:00
import torch as th
from torch.distributions.multivariate_normal import _batch_mahalanobis
def mahalanobis_alt(u, v, std):
    """Return the squared Mahalanobis-style distance between ``u`` and ``v``.

    Solves ``std @ x = (u - v)`` for ``x`` using only the lower triangle of
    ``std`` (forward substitution). It then returns ``sum(x ** 2)`` over the
    last two dimensions.

    Adapted from Fabian's code (public version).

    Args:
        u: Tensor whose trailing two dims form the right-hand side ``(n, k)``;
            leading dims are batch dims.
        v: Tensor broadcastable against ``u``.
        std: Lower-triangular ``(..., n, n)`` factor. Presumably the Cholesky
            factor of a covariance matrix; TODO confirm against callers. Its
            upper triangle is ignored.

    Returns:
        Tensor with the broadcast batch shape of ``u - v`` and ``std``.
    """
    delta = u - v
    # th.triangular_solve(b, A) is deprecated. linalg.solve_triangular takes
    # the triangular matrix FIRST and the right-hand side second.
    solved = th.linalg.solve_triangular(std, delta, upper=False)
    return solved.pow(2).sum([-2, -1])
def mahalanobis(u, v, chol):
    """Return the squared Mahalanobis distance of ``u`` from ``v``.

    Delegates to torch's private ``_batch_mahalanobis``. That helper computes
    ``x^T (L L^T)^{-1} x`` for a factor ``L`` whose batch shape must broadcast
    to that of ``x``.

    Args:
        u: Tensor of vectors, shape ``(..., n)``.
        v: Tensor broadcastable against ``u``.
        chol: Lower-triangular factor ``(..., n, n)`` (e.g. a Cholesky factor).

    Returns:
        Tensor of shape ``(...)``, one distance per vector.
    """
    return _batch_mahalanobis(chol, u - v)
def frob_sq(diff, is_spd=False):
    """Return the squared Frobenius norm of each item in a batch.

    Dimension 0 of ``diff`` is the batch dimension. All remaining dimensions
    are reduced, so any number of trailing dims is supported.

    Args:
        diff: Tensor of shape ``(batch, ...)``. Complex input is reduced over
            squared magnitudes.
        is_spd: If True, the caller promises each ``diff[i]`` is a symmetric
            (e.g. SPD) square matrix. The result is then ``trace(diff @ diff)``
            via ``_frob_sq_spd``, an identity that holds only for symmetric
            input.

    Returns:
        Tensor of shape ``(batch,)`` (real-valued).
    """
    # For symmetric matrices ||A||_F^2 == trace(A @ A); only symmetry is needed.
    if is_spd:
        return _frob_sq_spd(diff)
    reduce_dims = tuple(range(1, diff.dim()))
    # Summing squared entries avoids the deprecated th.norm and its
    # sqrt-then-square round trip. It also avoids th.norm's failure when
    # p='fro' is asked to reduce more than two dims.
    squared = diff.abs().square() if diff.is_complex() else diff.square()
    return squared.sum(dim=reduce_dims)
def _frob_sq_spd(diff):
    """Return ``trace(diff @ diff)`` for each square matrix in a batch.

    For symmetric ``diff`` this equals the squared Frobenius norm. It uses the
    identity ``trace(A @ A) = sum_ij A_ij * A_ji``, an O(n^2) elementwise
    product, instead of forming the O(n^3) matrix product and then discarding
    everything but its diagonal.

    Args:
        diff: Tensor of shape ``(..., n, n)``.

    Returns:
        Tensor of shape ``(...)``.

    Raises:
        RuntimeError: if ``diff`` is not a (batch of) square matrix. This
            matches the error type ``diff @ diff`` would have raised.
    """
    # Without this check, a (1, k) input would silently broadcast against its
    # (k, 1) transpose instead of failing.
    if diff.dim() < 2 or diff.shape[-1] != diff.shape[-2]:
        raise RuntimeError(
            f"_frob_sq_spd expects square matrices in the last two dims, got shape {tuple(diff.shape)}"
        )
    return (diff * diff.transpose(-2, -1)).sum(dim=(-2, -1))
def _batch_trace(x):
    """Return the trace of each matrix formed by the last two dims of ``x``.

    Leading dimensions are kept as batch dimensions.
    """
    main_diag = x.diagonal(dim1=-2, dim2=-1)
    return main_diag.sum(dim=-1)