Also check loss calc works for full cov case

This commit is contained in:
Dominik Moritz Roth 2024-12-21 18:53:27 +01:00
parent 3e2b988a2f
commit e83cb9a8a5

View File

@ -151,6 +151,10 @@ def test_full_covariance_projection(ProjectionClass):
eigvals = jnp.linalg.eigvalsh(cov)
assert jnp.all(eigvals > 0)
# Check trust region loss computation works
tr_loss = proj.get_trust_region_loss(params, proj_params)
assert jnp.isfinite(tr_loss)
# Only check KL bounds for KL projection
if ProjectionClass in [KLProjection]:
kl = compute_gaussian_kl(proj_params, old_params, full_cov=True)