# itpal_jax/tests/test_projections.py
# 2024-12-21 17:48:36 +01:00
#
# 227 lines
# 8.6 KiB
# Python

import jax
import jax.numpy as jnp
import pytest
from itpal_jax import (
KLProjection,
WassersteinProjection,
FrobeniusProjection,
IdentityProjection
)
from typing import Dict
@pytest.fixture
def gaussian_params():
    """Provide a current/old pair of diagonal Gaussian parameter dicts.

    The returned mapping has two entries, ``"params"`` and ``"old_params"``,
    each holding a ``"loc"`` and a ``"scale"`` array of shape (2, 2)
    (batch of two, two dimensions).  The old distribution has zero mean and
    a scale of 0.3 everywhere.
    """
    new_mean = jnp.array([[1.0, -1.0], [0.5, -0.5]])
    new_std = jnp.array([[0.5, 0.5], [0.5, 0.5]])

    current = {"loc": new_mean, "scale": new_std}
    previous = {
        "loc": jnp.zeros_like(new_mean),
        "scale": jnp.ones_like(new_std) * 0.3,
    }
    return {"params": current, "old_params": previous}
def compute_gaussian_kl(p_params: Dict[str, jnp.ndarray],
                        q_params: Dict[str, jnp.ndarray],
                        full_cov: bool = False) -> jnp.ndarray:
    """Compute the KL divergence KL(p || q) between two Gaussians.

    Uses the closed form
    ``0.5 * (tr(S_q^-1 S_p) - k + log|S_q| - log|S_p| + maha)``,
    where ``maha`` is the squared Mahalanobis distance between the means
    under ``S_q`` and ``k`` is the size of the last axis of ``loc``.

    Args:
        p_params: Parameters of ``p``. Always contains ``"loc"``; contains
            ``"scale"`` (per-dimension standard deviations) in the diagonal
            case, or ``"scale_tril"`` (lower-triangular factor ``L`` with
            covariance ``L @ L.T``) when ``full_cov`` is True.
        q_params: Parameters of ``q`` in the same layout as ``p_params``.
        full_cov: Read ``"scale_tril"`` instead of ``"scale"``.

    Returns:
        The KL divergence, reduced over the trailing event axis (one value
        per leading batch entry).
    """
    mean1, mean2 = p_params["loc"], q_params["loc"]
    k = mean1.shape[-1]
    diff = mean1 - mean2

    if full_cov:
        scale1, scale2 = p_params["scale_tril"], q_params["scale_tril"]
        # tr(S_q^-1 S_p) equals ||L_q^-1 L_p||_F^2, so the covariance
        # matrices themselves never have to be formed.
        solved = jax.scipy.linalg.solve_triangular(scale2, scale1, lower=True)
        trace_term = jnp.sum(solved ** 2, axis=(-2, -1))
        # log|L L^T| = 2 * sum(log(diag(L)))
        logdet1 = 2 * jnp.sum(jnp.log(jnp.diagonal(scale1, axis1=-2, axis2=-1)), axis=-1)
        logdet2 = 2 * jnp.sum(jnp.log(jnp.diagonal(scale2, axis1=-2, axis2=-1)), axis=-1)
        # Mahalanobis term: ||L_q^-1 (mu_p - mu_q)||^2
        whitened = jax.scipy.linalg.solve_triangular(
            scale2, diff[..., None], lower=True
        ).squeeze(-1)
        maha = jnp.sum(jnp.square(whitened), axis=-1)
    else:
        scale1, scale2 = p_params["scale"], q_params["scale"]
        # Diagonal case: every term factorises over dimensions.
        trace_term = jnp.sum((scale1 / scale2) ** 2, axis=-1)
        logdet1 = 2 * jnp.sum(jnp.log(scale1), axis=-1)
        logdet2 = 2 * jnp.sum(jnp.log(scale2), axis=-1)
        maha = jnp.sum(jnp.square(diff / scale2), axis=-1)

    return 0.5 * (trace_term - k + logdet2 - logdet1 + maha)
def print_gaussian_params(prefix: str, params: Dict[str, jnp.ndarray]):
    """Write a labelled dump of Gaussian parameters to stdout.

    Prints a header line built from ``prefix``, then ``params["loc"]``,
    then ``"scale"`` and/or ``"scale_tril"`` when those keys are present.
    The ``scale_tril`` value is placed on its own line below its label.
    """
    print(f"\n{prefix}:")
    print(f" loc: {params['loc']}")
    # Each optional entry is paired with the separator that follows its label.
    for key, separator in (("scale", " "), ("scale_tril", "\n")):
        if key in params:
            print(f" {key}:{separator}{params[key]}")
@pytest.mark.parametrize("ProjectionClass,needs_cpp", [
    (KLProjection, True),
    (WassersteinProjection, False),
    (FrobeniusProjection, False),
    (IdentityProjection, False),
])
def test_diagonal_projection(ProjectionClass, needs_cpp, gaussian_params):
    """Check that each projection returns well-formed diagonal parameters.

    The fixture's parameters are projected against its old parameters with
    ``mean_bound = cov_bound = 0.1``.  For every projection class the
    projected ``loc``/``scale`` must keep their shapes, be finite, and the
    scale must be strictly positive.  For the KL and Wasserstein
    projections the KL divergence from the old distribution must also stay
    within ``1.1 * (mean_bound + cov_bound)``.

    Cases flagged ``needs_cpp`` are skipped when ``cpp_projection`` cannot
    be imported.
    """
    if needs_cpp:
        try:
            import cpp_projection  # noqa: F401 -- only probing availability
        except ImportError:
            pytest.skip("cpp_projection not available")

    mean_bound, cov_bound = 0.1, 0.1
    proj = ProjectionClass(mean_bound=mean_bound, cov_bound=cov_bound,
                           contextual_std=False, full_cov=False)
    proj_params = proj.project(gaussian_params["params"], gaussian_params["old_params"])

    # Check basic properties
    assert proj_params["loc"].shape == gaussian_params["params"]["loc"].shape
    assert proj_params["scale"].shape == gaussian_params["params"]["scale"].shape
    assert jnp.all(jnp.isfinite(proj_params["loc"]))
    assert jnp.all(jnp.isfinite(proj_params["scale"]))
    assert jnp.all(proj_params["scale"] > 0)

    # Only check KL bounds for KL projection (and W2, which should approx hold as well)
    if ProjectionClass in [KLProjection, WassersteinProjection]:
        kl = compute_gaussian_kl(proj_params, gaussian_params["old_params"])
        max_kl = (mean_bound + cov_bound) * 1.1  # Allow 10% margin
        if jnp.any(kl > max_kl):
            # Dump everything needed to reproduce the violation before failing.
            print(f"\nProjection type: {ProjectionClass.__name__}")
            print(f"KL divergence: {kl}")
            print(f"Max allowed KL: {max_kl}")
            print_gaussian_params("Original params", gaussian_params["params"])
            print_gaussian_params("Old params", gaussian_params["old_params"])
            print_gaussian_params("Projected params", proj_params)
            # pytest.fail instead of `assert False`: a bare assert is removed
            # under `python -O`, which would let the violation pass silently.
            pytest.fail(f"KL divergence {kl} exceeds bound {max_kl}")
@pytest.mark.parametrize("ProjectionClass", [KLProjection, FrobeniusProjection])
def test_full_covariance_projection(ProjectionClass):
    """Check projections on Gaussians parameterised by ``scale_tril``.

    A batch of two 2-D Gaussians is projected against a zero-mean old
    distribution whose ``scale_tril`` is ``0.3 * I``.  The result must keep
    its shapes, be finite, stay lower triangular and give a covariance with
    strictly positive eigenvalues.  For ``KLProjection`` the KL divergence
    from the old distribution must additionally stay within
    ``1.1 * (mean_bound + cov_bound)``.

    Skipped when ``cpp_projection`` cannot be imported.
    """
    try:
        import cpp_projection
    except ImportError:
        pytest.skip("cpp_projection not available")

    mean_bound = cov_bound = 0.1

    # Inputs: current distribution and the reference ("old") distribution.
    new_mean = jnp.array([[1.0, -1.0], [0.5, -0.5]])
    new_tril = jnp.array([
        [[0.5, 0.0], [0.1, 0.4]],
        [[0.6, 0.0], [0.2, 0.3]],
    ])
    prev_tril = jnp.array([
        [[0.3, 0.0], [0.0, 0.3]],
        [[0.3, 0.0], [0.0, 0.3]],
    ])
    current = {"loc": new_mean, "scale_tril": new_tril}
    previous = {"loc": jnp.zeros_like(new_mean), "scale_tril": prev_tril}

    projector = ProjectionClass(
        mean_bound=mean_bound,
        cov_bound=cov_bound,
        contextual_std=False,
        full_cov=True,
    )
    projected = projector.project(current, previous)

    # Shapes and finiteness.
    out_mean = projected["loc"]
    assert out_mean.shape == new_mean.shape
    out_tril = projected["scale_tril"]
    assert out_tril.shape == new_tril.shape
    assert jnp.all(jnp.isfinite(out_mean))
    assert jnp.all(jnp.isfinite(out_tril))

    # Everything strictly above the diagonal has to vanish.
    strictly_upper = jnp.triu(out_tril, k=1)
    assert jnp.allclose(strictly_upper, 0.0, atol=1e-5)

    # L @ L^T must be positive definite.
    covariance = jnp.matmul(out_tril, jnp.swapaxes(out_tril, -1, -2))
    spectrum = jnp.linalg.eigvalsh(covariance)
    assert jnp.all(spectrum > 0)

    # The KL bound is only enforced for the KL projection.
    if ProjectionClass in [KLProjection]:
        kl = compute_gaussian_kl(projected, previous, full_cov=True)
        max_kl = (mean_bound + cov_bound) * 1.1  # 10% slack on the bound
        assert jnp.all(kl <= max_kl), f"KL divergence {kl} exceeds bound {max_kl}"
def test_contextual_vs_noncontextual():
    """Check that a non-contextual projection shares one scale across the batch.

    Only the ``contextual_std=False`` case is exercised: after a
    ``FrobeniusProjection`` of inputs whose scales differ per batch entry,
    the standard deviation of the projected scale along the batch axis
    must be (numerically) zero.
    """
    new_mean = jnp.array([[1.0, -1.0], [0.5, -0.5]])
    new_std = jnp.array([[0.5, 0.6], [0.7, 0.8]])
    current = {"loc": new_mean, "scale": new_std}
    previous = {
        "loc": jnp.zeros_like(new_mean),
        "scale": jnp.ones_like(new_std) * 0.3,
    }

    # Projection configured with contextual_std=False.
    projector = FrobeniusProjection(
        mean_bound=0.1,
        cov_bound=0.1,
        contextual_std=False,
        full_cov=False,
    )
    projected = projector.project(current, previous)

    # Identical rows along the batch axis give zero spread per dimension.
    batch_spread = jnp.std(projected["scale"], axis=0)
    assert jnp.allclose(batch_spread, 0.0, atol=1e-5)
@pytest.mark.parametrize("ProjectionClass,needs_cpp", [
    (KLProjection, True),
    (WassersteinProjection, False),
    (FrobeniusProjection, False),
    (IdentityProjection, False),
])
def test_gradient_flow(ProjectionClass, needs_cpp):
    """Check that ``jax.grad`` works through every projection.

    Two scalar losses are differentiated with respect to the input
    parameters: the squared norm of the projected ``loc``/``scale`` and the
    projection's own trust-region loss.  Both gradients must contain
    ``"loc"`` and ``"scale"`` entries with the input shapes and finite
    values.

    Cases flagged ``needs_cpp`` are skipped when ``cpp_projection`` cannot
    be imported.
    """
    if needs_cpp:
        try:
            import cpp_projection
        except ImportError:
            pytest.skip("cpp_projection not available")

    # Single-sample, two-dimensional inputs.
    new_mean = jnp.array([[1.0, -1.0]])
    new_std = jnp.array([[0.5, 0.5]])
    params = {"loc": new_mean, "scale": new_std}
    old_params = {
        "loc": jnp.zeros_like(new_mean),
        "scale": jnp.ones_like(new_std) * 0.3,
    }

    proj = ProjectionClass(mean_bound=0.1, cov_bound=0.1)

    def projected_sq_norm(p):
        # Loss that depends on the projection output only.
        projected = proj.project(p, old_params)
        return jnp.sum(projected["loc"]**2 + projected["scale"]**2)

    def trust_region_loss(p):
        # Loss comparing the raw parameters with their projection.
        projected = proj.project(p, old_params)
        return proj.get_trust_region_loss(p, projected)

    gradients = [
        jax.grad(loss_fn)(params)
        for loss_fn in (projected_sq_norm, trust_region_loss)
    ]

    keys = ("loc", "scale")
    for grads in gradients:
        assert all(key in grads for key in keys)
        for key in keys:
            assert grads[key].shape == params[key].shape
        for key in keys:
            assert jnp.all(jnp.isfinite(grads[key]))