Also check loss calc works for full cov case
This commit is contained in:
parent
3e2b988a2f
commit
e83cb9a8a5
@ -151,6 +151,10 @@ def test_full_covariance_projection(ProjectionClass):
|
||||
eigvals = jnp.linalg.eigvalsh(cov)
|
||||
assert jnp.all(eigvals > 0)
|
||||
|
||||
# Check trust region loss computation works
|
||||
tr_loss = proj.get_trust_region_loss(params, proj_params)
|
||||
assert jnp.isfinite(tr_loss)
|
||||
|
||||
# Only check KL bounds for KL projection
|
||||
if ProjectionClass in [KLProjection]:
|
||||
kl = compute_gaussian_kl(proj_params, old_params, full_cov=True)
|
||||
|
Loading…
Reference in New Issue
Block a user