From 416dde202d25f85a4abb23a1ef7974a022de704e Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 27 Jun 2022 13:44:08 +0200 Subject: [PATCH] Factored out frob_sq and perf improvement for spd input --- sb3_trl/misc/norm.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/sb3_trl/misc/norm.py b/sb3_trl/misc/norm.py index 4a9cbf2..5715334 100644 --- a/sb3_trl/misc/norm.py +++ b/sb3_trl/misc/norm.py @@ -10,3 +10,18 @@ def mahalanobis_blub(u, v, std): def mahalanobis(u, v, cov): delta = u - v return _batch_mahalanobis(cov, delta) + + +def frob_sq(diff, is_spd=False): + # If diff is spd, we can use a more perfromant algorithm + if is_spd: + return _frob_sq_spd(diff) + return th.norm(diff, p='fro', dim=tuple(range(1, diff.dim()))).pow(2) + + +def _frob_sq_spd(diff): + return _batch_trace(diff @ diff) + + +def _batch_trace(x): + return th.diagonal(x, dim1=-2, dim2=-1).sum(-1)