"""Tests for the projection classes exported by ``itpal_jax``."""
from typing import Dict

import jax
import jax.numpy as jnp
import pytest

from itpal_jax import (
    FrobeniusProjection,
    IdentityProjection,
    KLProjection,
    WassersteinProjection,
)


def _skip_without_cpp_projection() -> None:
    """Skip the calling test when ``cpp_projection`` cannot be imported."""
    try:
        import cpp_projection  # noqa: F401  (imported only to probe availability)
    except ImportError:
        pytest.skip("cpp_projection not available")


@pytest.fixture
def gaussian_params() -> Dict[str, Dict[str, jnp.ndarray]]:
    """Create test Gaussian parameters.

    Returns a dict with two entries, ``"params"`` and ``"old_params"``, each
    holding a ``"loc"`` and a ``"scale"`` array of shape (2, 2).
    """
    loc = jnp.array([[1.0, -1.0], [0.5, -0.5]])  # batch_size=2, dim=2
    scale = jnp.array([[0.5, 0.5], [0.5, 0.5]])
    old_loc = jnp.zeros_like(loc)
    old_scale = jnp.ones_like(scale) * 0.3

    return {
        "params": {"loc": loc, "scale": scale},
        "old_params": {"loc": old_loc, "scale": old_scale},
    }


def compute_gaussian_kl(p_params: Dict[str, jnp.ndarray],
                        q_params: Dict[str, jnp.ndarray],
                        full_cov: bool = False) -> jnp.ndarray:
    """Compute the KL divergence KL(p || q) between two Gaussians.

    Args:
        p_params: Parameters of ``p``. Must contain ``"loc"`` and either
            ``"scale"`` (diagonal case, used as a per-dimension standard
            deviation) or ``"scale_tril"`` (full case, used as the lower
            Cholesky factor of the covariance).
        q_params: Parameters of ``q`` with the same keys as ``p_params``.
        full_cov: If True read ``"scale_tril"``, otherwise read ``"scale"``.

    Returns:
        The KL divergence, with the last (event) axis of ``loc`` reduced away.
    """
    mean1, mean2 = p_params["loc"], q_params["loc"]
    k = mean1.shape[-1]

    if full_cov:
        scale1, scale2 = p_params["scale_tril"], q_params["scale_tril"]

        # tr(cov2^-1 cov1) equals the squared Frobenius norm of L2^-1 L1,
        # so the covariance matrices never need to be formed explicitly.
        solved = jax.scipy.linalg.solve_triangular(scale2, scale1, lower=True)
        trace_term = jnp.sum(solved ** 2, axis=(-2, -1))

        # log det(cov) = 2 * sum(log(diag(L))) for a Cholesky factor L
        logdet1 = 2 * jnp.sum(
            jnp.log(jnp.diagonal(scale1, axis1=-2, axis2=-1)), axis=-1
        )
        logdet2 = 2 * jnp.sum(
            jnp.log(jnp.diagonal(scale2, axis1=-2, axis2=-1)), axis=-1
        )

        # Mahalanobis term: || L2^-1 (mean1 - mean2) ||^2
        diff = mean1 - mean2
        whitened = jax.scipy.linalg.solve_triangular(
            scale2, diff[..., None], lower=True
        ).squeeze(-1)
        maha = jnp.sum(jnp.square(whitened), axis=-1)
    else:
        scale1, scale2 = p_params["scale"], q_params["scale"]
        # Diagonal case: every term reduces to an elementwise expression
        trace_term = jnp.sum((scale1 / scale2) ** 2, axis=-1)
        logdet1 = 2 * jnp.sum(jnp.log(scale1), axis=-1)
        logdet2 = 2 * jnp.sum(jnp.log(scale2), axis=-1)
        maha = jnp.sum(jnp.square((mean1 - mean2) / scale2), axis=-1)

    return 0.5 * (trace_term - k + logdet2 - logdet1 + maha)


def print_gaussian_params(prefix: str, params: Dict[str, jnp.ndarray]) -> None:
    """Pretty print Gaussian parameters under the heading ``prefix``."""
    print(f"\n{prefix}:")
    print(f" loc: {params['loc']}")
    if "scale" in params:
        print(f" scale: {params['scale']}")
    if "scale_tril" in params:
        print(f" scale_tril:\n{params['scale_tril']}")


@pytest.mark.parametrize("ProjectionClass,needs_cpp", [
    (KLProjection, True),
    (WassersteinProjection, False),
    (FrobeniusProjection, False),
    (IdentityProjection, False),
])
def test_diagonal_projection(ProjectionClass, needs_cpp, gaussian_params):
    """Test projections with diagonal covariance."""
    if needs_cpp:
        _skip_without_cpp_projection()

    mean_bound, cov_bound = 0.1, 0.1
    proj = ProjectionClass(mean_bound=mean_bound, cov_bound=cov_bound,
                           contextual_std=False, full_cov=False)
    proj_params = proj.project(
        gaussian_params["params"], gaussian_params["old_params"]
    )

    # Check basic properties
    assert proj_params["loc"].shape == gaussian_params["params"]["loc"].shape
    assert proj_params["scale"].shape == gaussian_params["params"]["scale"].shape
    assert jnp.all(jnp.isfinite(proj_params["loc"]))
    assert jnp.all(jnp.isfinite(proj_params["scale"]))
    assert jnp.all(proj_params["scale"] > 0)

    # Only check KL bounds for KL projection (and W2, which should approx hold as well)
    if ProjectionClass in [KLProjection, WassersteinProjection]:
        kl = compute_gaussian_kl(proj_params, gaussian_params["old_params"])
        max_kl = (mean_bound + cov_bound) * 1.1  # Allow 10% margin

        # Phrased as "not all within bound" so that a NaN KL is also treated
        # as a violation, matching the full-covariance test.
        if not bool(jnp.all(kl <= max_kl)):
            print(f"\nProjection type: {ProjectionClass.__name__}")
            print(f"KL divergence: {kl}")
            print(f"Max allowed KL: {max_kl}")
            print_gaussian_params("Original params", gaussian_params["params"])
            print_gaussian_params("Old params", gaussian_params["old_params"])
            print_gaussian_params("Projected params", proj_params)
            # pytest.fail rather than `assert False`, which `python -O` strips
            pytest.fail(f"KL divergence {kl} exceeds bound {max_kl}")


@pytest.mark.parametrize("ProjectionClass", [KLProjection, FrobeniusProjection])
def test_full_covariance_projection(ProjectionClass):
    """Test projections with full covariance."""
    # NOTE(review): the diagonal test marks FrobeniusProjection as not needing
    # cpp_projection; confirm whether this skip is required for it here.
    _skip_without_cpp_projection()

    # Create test parameters
    loc = jnp.array([[1.0, -1.0], [0.5, -0.5]])
    scale_tril = jnp.array([
        [[0.5, 0.0], [0.1, 0.4]],
        [[0.6, 0.0], [0.2, 0.3]],
    ])
    params = {"loc": loc, "scale_tril": scale_tril}

    old_loc = jnp.zeros_like(loc)
    old_scale_tril = jnp.array([
        [[0.3, 0.0], [0.0, 0.3]],
        [[0.3, 0.0], [0.0, 0.3]],
    ])
    old_params = {"loc": old_loc, "scale_tril": old_scale_tril}

    mean_bound, cov_bound = 0.1, 0.1
    proj = ProjectionClass(mean_bound=mean_bound, cov_bound=cov_bound,
                           contextual_std=False, full_cov=True)
    proj_params = proj.project(params, old_params)

    # Check basic properties
    assert proj_params["loc"].shape == loc.shape
    assert proj_params["scale_tril"].shape == scale_tril.shape
    assert jnp.all(jnp.isfinite(proj_params["loc"]))
    assert jnp.all(jnp.isfinite(proj_params["scale_tril"]))

    # Verify scale_tril is lower triangular
    upper_tri = jnp.triu(proj_params["scale_tril"], k=1)
    assert jnp.allclose(upper_tri, 0.0, atol=1e-5)

    # Verify positive definiteness
    cov = jnp.matmul(
        proj_params["scale_tril"],
        jnp.swapaxes(proj_params["scale_tril"], -1, -2),
    )
    eigvals = jnp.linalg.eigvalsh(cov)
    assert jnp.all(eigvals > 0)

    # Only check KL bounds for KL projection
    if ProjectionClass is KLProjection:
        kl = compute_gaussian_kl(proj_params, old_params, full_cov=True)
        max_kl = (mean_bound + cov_bound) * 1.1  # Allow 10% margin
        assert jnp.all(kl <= max_kl), f"KL divergence {kl} exceeds bound {max_kl}"


def test_contextual_vs_noncontextual():
    """Test that a non-contextual projection yields one scale for the whole batch.

    Only the ``contextual_std=False`` case is exercised here.
    """
    loc = jnp.array([[1.0, -1.0], [0.5, -0.5]])
    scale = jnp.array([[0.5, 0.6], [0.7, 0.8]])
    params = {"loc": loc, "scale": scale}

    old_loc = jnp.zeros_like(loc)
    old_scale = jnp.ones_like(scale) * 0.3
    old_params = {"loc": old_loc, "scale": old_scale}

    # Test with contextual=False
    proj_noncontextual = FrobeniusProjection(
        mean_bound=0.1, cov_bound=0.1,
        contextual_std=False, full_cov=False,
    )
    proj_params_nonctx = proj_noncontextual.project(params, old_params)

    # Non-contextual should have same scale for all batch elements
    scale_diff = jnp.std(proj_params_nonctx["scale"], axis=0)
    assert jnp.allclose(scale_diff, 0.0, atol=1e-5)


@pytest.mark.parametrize("ProjectionClass,needs_cpp", [
    (KLProjection, True),
    (WassersteinProjection, False),
    (FrobeniusProjection, False),
    (IdentityProjection, False),
])
def test_gradient_flow(ProjectionClass, needs_cpp):
    """Test that gradients can flow through all projections."""
    if needs_cpp:
        _skip_without_cpp_projection()

    # Create test parameters
    loc = jnp.array([[1.0, -1.0]])
    scale = jnp.array([[0.5, 0.5]])
    params = {"loc": loc, "scale": scale}

    old_params = {
        "loc": jnp.zeros_like(loc),
        "scale": jnp.ones_like(scale) * 0.3,
    }

    proj = ProjectionClass(mean_bound=0.1, cov_bound=0.1)

    # Test gradient through projection
    def proj_loss_fn(p):
        proj_p = proj.project(p, old_params)
        return jnp.sum(proj_p["loc"]**2 + proj_p["scale"]**2)

    proj_grads = jax.grad(proj_loss_fn)(params)

    # Test gradient through trust region loss
    def tr_loss_fn(p):
        proj_p = proj.project(p, old_params)
        return proj.get_trust_region_loss(p, proj_p)

    tr_grads = jax.grad(tr_loss_fn)(params)

    # Check gradients exist and have correct shape
    for grads in [proj_grads, tr_grads]:
        assert "loc" in grads and "scale" in grads
        assert grads["loc"].shape == params["loc"].shape
        assert grads["scale"].shape == params["scale"].shape
        assert jnp.all(jnp.isfinite(grads["loc"]))
        assert jnp.all(jnp.isfinite(grads["scale"]))