import jax
import jax.numpy as jnp
import pytest
from itpal_jax import (
    KLProjection,
    WassersteinProjection,
    FrobeniusProjection,
    IdentityProjection,
)
from typing import Dict


@pytest.fixture
def gaussian_params():
    """Create test Gaussian parameters."""
    loc = jnp.array([[1.0, -1.0], [0.5, -0.5]])  # batch_size=2, dim=2
    scale = jnp.array([[0.5, 0.5], [0.5, 0.5]])
    old_loc = jnp.zeros_like(loc)
    old_scale = jnp.ones_like(scale) * 0.3
    return {
        "params": {"loc": loc, "scale": scale},
        "old_params": {"loc": old_loc, "scale": old_scale},
    }


def compute_gaussian_kl(p_params: Dict[str, jnp.ndarray],
                        q_params: Dict[str, jnp.ndarray],
                        full_cov: bool = False) -> jnp.ndarray:
    """Compute the KL divergence KL(p || q) between two Gaussians."""
    mean1, mean2 = p_params["loc"], q_params["loc"]
    k = mean1.shape[-1]

    if full_cov:
        scale1, scale2 = p_params["scale_tril"], q_params["scale_tril"]
        # Trace term: tr(Sigma2^-1 Sigma1) via a triangular solve with the
        # Cholesky factors, avoiding an explicit matrix inverse
        solved = jax.scipy.linalg.solve_triangular(scale2, scale1, lower=True)
        trace_term = jnp.sum(solved ** 2, axis=(-2, -1))
        # Log-determinant terms from the Cholesky diagonals
        logdet1 = 2 * jnp.sum(jnp.log(jnp.diagonal(scale1, axis1=-2, axis2=-1)), axis=-1)
        logdet2 = 2 * jnp.sum(jnp.log(jnp.diagonal(scale2, axis1=-2, axis2=-1)), axis=-1)
        # Mahalanobis term: (mu1 - mu2)^T Sigma2^-1 (mu1 - mu2)
        diff = mean1 - mean2
        maha = jnp.sum(jnp.square(jax.scipy.linalg.solve_triangular(
            scale2, diff[..., None], lower=True
        ).squeeze(-1)), axis=-1)
    else:
        scale1, scale2 = p_params["scale"], q_params["scale"]
        # Diagonal case: all terms reduce to elementwise operations
        trace_term = jnp.sum((scale1 / scale2) ** 2, axis=-1)
        logdet1 = 2 * jnp.sum(jnp.log(scale1), axis=-1)
        logdet2 = 2 * jnp.sum(jnp.log(scale2), axis=-1)
        maha = jnp.sum(jnp.square((mean1 - mean2) / scale2), axis=-1)

    return 0.5 * (trace_term - k + logdet2 - logdet1 + maha)


def print_gaussian_params(prefix: str, params: Dict[str, jnp.ndarray]):
    """Pretty-print Gaussian parameters."""
    print(f"\n{prefix}:")
    print(f"  loc: {params['loc']}")
    if "scale" in params:
        print(f"  scale: {params['scale']}")
    if "scale_tril" in params:
        print(f"  scale_tril:\n{params['scale_tril']}")
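# Sanity check for the KL helper above (an addition, not part of the original
# suite): in 1-D the divergence has the closed form
#   KL = log(s2/s1) + (s1^2 + (m1 - m2)^2) / (2 * s2^2) - 1/2,
# so the batched implementation can be checked against it directly.
def test_kl_helper_matches_closed_form():
    """compute_gaussian_kl should reproduce the scalar closed-form KL."""
    p = {"loc": jnp.array([[1.0]]), "scale": jnp.array([[0.5]])}
    q = {"loc": jnp.array([[0.0]]), "scale": jnp.array([[1.0]])}
    expected = jnp.log(1.0 / 0.5) + (0.5 ** 2 + 1.0 ** 2) / (2 * 1.0 ** 2) - 0.5
    assert jnp.allclose(compute_gaussian_kl(p, q), expected, atol=1e-6)
    # KL(p || p) must be exactly zero
    assert jnp.allclose(compute_gaussian_kl(p, p), 0.0, atol=1e-6)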
@pytest.mark.parametrize("ProjectionClass,needs_cpp", [
    (KLProjection, True),
    (WassersteinProjection, False),
    (FrobeniusProjection, False),
    (IdentityProjection, False),
])
def test_diagonal_projection(ProjectionClass, needs_cpp, gaussian_params):
    """Test projections with diagonal covariance."""
    if needs_cpp:
        try:
            import cpp_projection  # noqa: F401
        except ImportError:
            pytest.skip("cpp_projection not available")

    mean_bound, cov_bound = 0.1, 0.1
    proj = ProjectionClass(mean_bound=mean_bound, cov_bound=cov_bound,
                           contextual_std=False, full_cov=False)
    proj_params = proj.project(gaussian_params["params"],
                               gaussian_params["old_params"])

    # Check basic properties
    assert proj_params["loc"].shape == gaussian_params["params"]["loc"].shape
    assert proj_params["scale"].shape == gaussian_params["params"]["scale"].shape
    assert jnp.all(jnp.isfinite(proj_params["loc"]))
    assert jnp.all(jnp.isfinite(proj_params["scale"]))
    assert jnp.all(proj_params["scale"] > 0)

    # Only check KL bounds for the KL projection
    if ProjectionClass is KLProjection:
        kl = compute_gaussian_kl(proj_params, gaussian_params["old_params"])
        max_kl = (mean_bound + cov_bound) * 1.1  # Allow 10% margin
        if jnp.any(kl > max_kl):
            print(f"\nProjection type: {ProjectionClass.__name__}")
            print(f"KL divergence: {kl}")
            print(f"Max allowed KL: {max_kl}")
            print_gaussian_params("Original params", gaussian_params["params"])
            print_gaussian_params("Old params", gaussian_params["old_params"])
            print_gaussian_params("Projected params", proj_params)
            assert False, f"KL divergence {kl} exceeds bound {max_kl}"


@pytest.mark.parametrize("ProjectionClass,needs_cpp", [
    (KLProjection, True),
    (FrobeniusProjection, False),
])
def test_full_covariance_projection(ProjectionClass, needs_cpp):
    """Test projections with full covariance."""
    if needs_cpp:
        try:
            import cpp_projection  # noqa: F401
        except ImportError:
            pytest.skip("cpp_projection not available")

    # Create test parameters
    loc = jnp.array([[1.0, -1.0], [0.5, -0.5]])
    scale_tril = jnp.array([
        [[0.5, 0.0], [0.1, 0.4]],
        [[0.6, 0.0], [0.2, 0.3]],
    ])
    params = {"loc": loc, "scale_tril": scale_tril}

    old_loc = jnp.zeros_like(loc)
    old_scale_tril = jnp.array([
        [[0.3, 0.0], [0.0, 0.3]],
        [[0.3, 0.0], [0.0, 0.3]],
    ])
    old_params = {"loc": old_loc, "scale_tril": old_scale_tril}

    mean_bound, cov_bound = 0.1, 0.1
    proj = ProjectionClass(mean_bound=mean_bound, cov_bound=cov_bound,
                           contextual_std=False, full_cov=True)
    proj_params = proj.project(params, old_params)

    # Check basic properties
    assert proj_params["loc"].shape == loc.shape
    assert proj_params["scale_tril"].shape == scale_tril.shape
    assert jnp.all(jnp.isfinite(proj_params["loc"]))
    assert jnp.all(jnp.isfinite(proj_params["scale_tril"]))

    # Verify scale_tril is lower triangular
    upper_tri = jnp.triu(proj_params["scale_tril"], k=1)
    assert jnp.allclose(upper_tri, 0.0, atol=1e-5)

    # Verify positive definiteness
    cov = jnp.matmul(proj_params["scale_tril"],
                     jnp.swapaxes(proj_params["scale_tril"], -1, -2))
    eigvals = jnp.linalg.eigvalsh(cov)
    assert jnp.all(eigvals > 0)

    # Check trust region loss computation works
    tr_loss = proj.get_trust_region_loss(params, proj_params)
    assert jnp.isfinite(tr_loss)

    # Only check KL bounds for the KL projection
    if ProjectionClass is KLProjection:
        kl = compute_gaussian_kl(proj_params, old_params, full_cov=True)
        max_kl = (mean_bound + cov_bound) * 1.1  # Allow 10% margin
        assert jnp.all(kl <= max_kl), f"KL divergence {kl} exceeds bound {max_kl}"


def test_contextual_vs_noncontextual():
    """With contextual_std=False the projected scale is shared across the batch."""
    loc = jnp.array([[1.0, -1.0], [0.5, -0.5]])
    scale = jnp.array([[0.5, 0.6], [0.7, 0.8]])
    params = {"loc": loc, "scale": scale}

    old_loc = jnp.zeros_like(loc)
    old_scale = jnp.ones_like(scale) * 0.3
    old_params = {"loc": old_loc, "scale": old_scale}

    proj_noncontextual = FrobeniusProjection(
        mean_bound=0.1, cov_bound=0.1, contextual_std=False, full_cov=False
    )
    proj_params_nonctx = proj_noncontextual.project(params, old_params)

    # Non-contextual: the scale must be identical for all batch elements
    scale_std = jnp.std(proj_params_nonctx["scale"], axis=0)
    assert jnp.allclose(scale_std, 0.0, atol=1e-5)
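# Cross-check for the two KL branches (an addition, not part of the original
# suite): a diagonal Gaussian written as scale_tril = diag(scale) must give
# the same divergence through the full-covariance path as through the
# diagonal path.
def test_kl_helper_full_cov_matches_diagonal():
    """Full-covariance and diagonal KL branches agree on diagonal inputs."""
    p_loc = jnp.array([[1.0, -1.0]])
    q_loc = jnp.zeros_like(p_loc)
    p_scale = jnp.array([[0.5, 0.6]])
    q_scale = jnp.array([[0.3, 0.4]])

    kl_diag = compute_gaussian_kl(
        {"loc": p_loc, "scale": p_scale},
        {"loc": q_loc, "scale": q_scale},
    )
    kl_full = compute_gaussian_kl(
        {"loc": p_loc, "scale_tril": jax.vmap(jnp.diag)(p_scale)},
        {"loc": q_loc, "scale_tril": jax.vmap(jnp.diag)(q_scale)},
        full_cov=True,
    )
    assert jnp.allclose(kl_diag, kl_full, atol=1e-5)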
@pytest.mark.parametrize("ProjectionClass,needs_cpp", [
    (KLProjection, True),
    (WassersteinProjection, False),
    (FrobeniusProjection, False),
    (IdentityProjection, False),
])
def test_gradient_flow(ProjectionClass, needs_cpp):
    """Test that gradients can flow through all projections."""
    if needs_cpp:
        try:
            import cpp_projection  # noqa: F401
        except ImportError:
            pytest.skip("cpp_projection not available")

    # Create test parameters
    loc = jnp.array([[1.0, -1.0]])
    scale = jnp.array([[0.5, 0.5]])
    params = {"loc": loc, "scale": scale}
    old_params = {
        "loc": jnp.zeros_like(loc),
        "scale": jnp.ones_like(scale) * 0.3,
    }

    proj = ProjectionClass(mean_bound=0.1, cov_bound=0.1)

    # Gradient through the projection itself
    def proj_loss_fn(p):
        proj_p = proj.project(p, old_params)
        return jnp.sum(proj_p["loc"] ** 2 + proj_p["scale"] ** 2)

    proj_grads = jax.grad(proj_loss_fn)(params)

    # Gradient through the trust region loss
    def tr_loss_fn(p):
        proj_p = proj.project(p, old_params)
        return proj.get_trust_region_loss(p, proj_p)

    tr_grads = jax.grad(tr_loss_fn)(params)

    # Gradients must exist, match parameter shapes, and be finite
    for grads in [proj_grads, tr_grads]:
        assert "loc" in grads and "scale" in grads
        assert grads["loc"].shape == params["loc"].shape
        assert grads["scale"].shape == params["scale"].shape
        assert jnp.all(jnp.isfinite(grads["loc"]))
        assert jnp.all(jnp.isfinite(grads["scale"]))
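
# Convenience entry point (an addition, not part of the original suite) for
# eyeballing one projection outside of pytest. It only reuses calls already
# exercised above, so it should run wherever the Frobenius tests pass.
if __name__ == "__main__":
    params = {"loc": jnp.array([[1.0, -1.0]]), "scale": jnp.array([[0.5, 0.5]])}
    old_params = {"loc": jnp.zeros((1, 2)), "scale": jnp.ones((1, 2)) * 0.3}
    proj = FrobeniusProjection(mean_bound=0.1, cov_bound=0.1,
                               contextual_std=False, full_cov=False)
    proj_params = proj.project(params, old_params)
    print_gaussian_params("Projected params", proj_params)
    print(f"\nKL before projection: {compute_gaussian_kl(params, old_params)}")
    print(f"KL after projection:  {compute_gaussian_kl(proj_params, old_params)}")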