Add norms
This commit is contained in:
parent
c6a12aa27b
commit
add8e92b4a
27
fancy_rl/norm.py
Normal file
27
fancy_rl/norm.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
import torch as th
|
||||||
|
from torch.distributions.multivariate_normal import _batch_mahalanobis
|
||||||
|
|
||||||
|
|
||||||
|
def mahalanobis_alt(u, v, std):
|
||||||
|
delta = u - v
|
||||||
|
return th.triangular_solve(delta, std, upper=False)[0].pow(2).sum([-2, -1])
|
||||||
|
|
||||||
|
|
||||||
|
def mahalanobis(u, v, chol):
|
||||||
|
delta = u - v
|
||||||
|
return _batch_mahalanobis(chol, delta)
|
||||||
|
|
||||||
|
|
||||||
|
def frob_sq(diff, is_spd=False):
|
||||||
|
# If diff is spd, we can use a (probably) more performant algorithm
|
||||||
|
if is_spd:
|
||||||
|
return _frob_sq_spd(diff)
|
||||||
|
return th.norm(diff, p='fro', dim=tuple(range(1, diff.dim()))).pow(2)
|
||||||
|
|
||||||
|
|
||||||
|
def _frob_sq_spd(diff):
|
||||||
|
return _batch_trace(diff @ diff)
|
||||||
|
|
||||||
|
|
||||||
|
def _batch_trace(x):
|
||||||
|
return th.diagonal(x, dim1=-2, dim2=-1).sum(-1)
|
Loading…
Reference in New Issue
Block a user