Also check loss calc works for full cov case
This commit is contained in:
parent
3e2b988a2f
commit
e83cb9a8a5
@ -151,6 +151,10 @@ def test_full_covariance_projection(ProjectionClass):
|
|||||||
eigvals = jnp.linalg.eigvalsh(cov)
|
eigvals = jnp.linalg.eigvalsh(cov)
|
||||||
assert jnp.all(eigvals > 0)
|
assert jnp.all(eigvals > 0)
|
||||||
|
|
||||||
|
# Check trust region loss computation works
|
||||||
|
tr_loss = proj.get_trust_region_loss(params, proj_params)
|
||||||
|
assert jnp.isfinite(tr_loss)
|
||||||
|
|
||||||
# Only check KL bounds for KL projection
|
# Only check KL bounds for KL projection
|
||||||
if ProjectionClass in [KLProjection]:
|
if ProjectionClass in [KLProjection]:
|
||||||
kl = compute_gaussian_kl(proj_params, old_params, full_cov=True)
|
kl = compute_gaussian_kl(proj_params, old_params, full_cov=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user