diff --git a/metastable_baselines/misc/norm.py b/metastable_baselines/misc/norm.py index f40d319..894451b 100644 --- a/metastable_baselines/misc/norm.py +++ b/metastable_baselines/misc/norm.py @@ -3,6 +3,10 @@ from torch.distributions.multivariate_normal import _batch_mahalanobis def mahalanobis_alt(u, v, std): + """ + Stolen from Fabian's Code (Public Version) + + """ delta = u - v return th.triangular_solve(delta, std, upper=False)[0].pow(2).sum([-2, -1])