From e83cb9a8a551d8b90a94600950fe31c1572d7694 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 21 Dec 2024 18:53:27 +0100 Subject: [PATCH] Also check loss calc works for full cov case --- tests/test_projections.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_projections.py b/tests/test_projections.py index d71ad04..098b8fd 100644 --- a/tests/test_projections.py +++ b/tests/test_projections.py @@ -151,6 +151,10 @@ def test_full_covariance_projection(ProjectionClass): eigvals = jnp.linalg.eigvalsh(cov) assert jnp.all(eigvals > 0) + # Check trust region loss computation works + tr_loss = proj.get_trust_region_loss(params, proj_params) + assert jnp.isfinite(tr_loss) + # Only check KL bounds for KL projection if ProjectionClass in [KLProjection]: kl = compute_gaussian_kl(proj_params, old_params, full_cov=True)