Updated tests (no check kl for w2)
This commit is contained in:
parent
8e991ae05b
commit
9fb0014a99
@ -94,8 +94,8 @@ def test_diagonal_projection(ProjectionClass, needs_cpp, gaussian_params):
|
||||
assert jnp.all(jnp.isfinite(proj_params["scale"]))
|
||||
assert jnp.all(proj_params["scale"] > 0)
|
||||
|
||||
# Only check KL bounds for KL projection (and W2, which should approx hold as well)
|
||||
if ProjectionClass in [KLProjection, WassersteinProjection]:
|
||||
# Only check KL bounds for KL projection
|
||||
if ProjectionClass in [KLProjection]:
|
||||
kl = compute_gaussian_kl(proj_params, gaussian_params["old_params"])
|
||||
max_kl = (mean_bound + cov_bound) * 1.1 # Allow 10% margin
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user