Updated tests (no check kl for w2)

This commit is contained in:
Dominik Moritz Roth 2024-12-21 18:31:07 +01:00
parent 8e991ae05b
commit 9fb0014a99

View File

@ -94,8 +94,8 @@ def test_diagonal_projection(ProjectionClass, needs_cpp, gaussian_params):
assert jnp.all(jnp.isfinite(proj_params["scale"])) assert jnp.all(jnp.isfinite(proj_params["scale"]))
assert jnp.all(proj_params["scale"] > 0) assert jnp.all(proj_params["scale"] > 0)
# Only check KL bounds for KL projection (and W2, which should approx hold as well) # Only check KL bounds for KL projection
if ProjectionClass in [KLProjection, WassersteinProjection]: if ProjectionClass in [KLProjection]:
kl = compute_gaussian_kl(proj_params, gaussian_params["old_params"]) kl = compute_gaussian_kl(proj_params, gaussian_params["old_params"])
max_kl = (mean_bound + cov_bound) * 1.1 # Allow 10% margin max_kl = (mean_bound + cov_bound) * 1.1 # Allow 10% margin