227 lines
8.5 KiB
Python
227 lines
8.5 KiB
Python
import jax
|
|
import jax.numpy as jnp
|
|
import pytest
|
|
from itpal_jax import (
|
|
KLProjection,
|
|
WassersteinProjection,
|
|
FrobeniusProjection,
|
|
IdentityProjection
|
|
)
|
|
from typing import Dict
|
|
|
|
@pytest.fixture
def gaussian_params():
    """Provide a current/old pair of diagonal-Gaussian parameter dicts.

    Returns:
        A dict with the entries ``"params"`` and ``"old_params"``. Each maps
        ``"loc"`` and ``"scale"`` to a ``(2, 2)`` array (batch_size=2, dim=2).
        The old distribution has a zero ``loc`` and a constant ``scale`` of
        0.3; the current one has a non-zero ``loc`` and a ``scale`` of 0.5.
    """
    current = {
        "loc": jnp.array([[1.0, -1.0], [0.5, -0.5]]),  # batch_size=2, dim=2
        "scale": jnp.array([[0.5, 0.5], [0.5, 0.5]]),
    }
    # The old parameters mirror the shapes of the current ones.
    previous = {
        "loc": jnp.zeros_like(current["loc"]),
        "scale": jnp.ones_like(current["scale"]) * 0.3,
    }
    return {"params": current, "old_params": previous}
|
|
|
|
def compute_gaussian_kl(p_params: Dict[str, jnp.ndarray],
                        q_params: Dict[str, jnp.ndarray],
                        full_cov: bool = False) -> jnp.ndarray:
    """Compute the KL divergence KL(p || q) between two Gaussians.

    Evaluates::

        0.5 * (tr(S_q^-1 S_p) - k + log|S_q| - log|S_p| + maha)

    where ``k`` is the size of the last axis of ``loc`` and ``maha`` is the
    squared distance between the means, whitened by q's scale.

    Args:
        p_params: Parameters of p. Needs ``"loc"`` plus ``"scale"`` (used as
            a per-dimension standard deviation) or, when ``full_cov`` is
            true, ``"scale_tril"`` (used as a lower-triangular factor of the
            covariance).
        q_params: Parameters of q, with the same keys as ``p_params``.
        full_cov: Select the ``"scale_tril"`` (full covariance) code path
            instead of the diagonal ``"scale"`` one.

    Returns:
        The KL divergence with the event axis reduced, i.e. one value per
        leading (batch) index.

    Raises:
        KeyError: If a required key is missing from either dict.
    """
    loc_p, loc_q = p_params["loc"], q_params["loc"]
    k = loc_p.shape[-1]
    diff = loc_p - loc_q

    if full_cov:
        tril_p, tril_q = p_params["scale_tril"], q_params["scale_tril"]

        # tr(S_q^-1 S_p) equals the squared Frobenius norm of L_q^-1 L_p, so
        # the covariance matrices never have to be formed explicitly.
        whitened = jax.scipy.linalg.solve_triangular(tril_q, tril_p, lower=True)
        trace_term = jnp.sum(whitened ** 2, axis=(-2, -1))

        # log|S| = 2 * sum(log(diag(L))) for S = L L^T.
        logdet_p = 2 * jnp.sum(
            jnp.log(jnp.diagonal(tril_p, axis1=-2, axis2=-1)), axis=-1)
        logdet_q = 2 * jnp.sum(
            jnp.log(jnp.diagonal(tril_q, axis1=-2, axis2=-1)), axis=-1)

        # Mahalanobis term: || L_q^-1 (mu_p - mu_q) ||^2.
        whitened_diff = jax.scipy.linalg.solve_triangular(
            tril_q, diff[..., None], lower=True
        ).squeeze(-1)
        maha = jnp.sum(jnp.square(whitened_diff), axis=-1)
    else:
        scale_p, scale_q = p_params["scale"], q_params["scale"]

        # Diagonal case: every term reduces to an element-wise expression.
        trace_term = jnp.sum((scale_p / scale_q) ** 2, axis=-1)
        logdet_p = 2 * jnp.sum(jnp.log(scale_p), axis=-1)
        logdet_q = 2 * jnp.sum(jnp.log(scale_q), axis=-1)
        maha = jnp.sum(jnp.square(diff / scale_q), axis=-1)

    return 0.5 * (trace_term - k + logdet_q - logdet_p + maha)
|
|
|
|
def print_gaussian_params(prefix: str, params: Dict[str, jnp.ndarray]):
    """Write a labelled dump of Gaussian parameters to stdout.

    Args:
        prefix: Heading printed (after a blank line) above the values.
        params: Parameter dict. ``"loc"`` is always printed; ``"scale"`` and
            ``"scale_tril"`` are printed only when present.
    """
    heading = f"\n{prefix}:"
    print(heading)
    print(f" loc: {params['loc']}")

    has_diag_scale = "scale" in params
    has_tril_scale = "scale_tril" in params
    if has_diag_scale:
        print(f" scale: {params['scale']}")
    if has_tril_scale:
        # The matrix goes on its own line(s) below the label.
        print(f" scale_tril:\n{params['scale_tril']}")
|
|
|
|
@pytest.mark.parametrize("ProjectionClass,needs_cpp", [
    (KLProjection, True),
    (WassersteinProjection, False),
    (FrobeniusProjection, False),
    (IdentityProjection, False),
])
def test_diagonal_projection(ProjectionClass, needs_cpp, gaussian_params):
    """Test projections with diagonal covariance.

    Every projection must return finite ``loc``/``scale`` arrays of unchanged
    shape with strictly positive scales. For ``KLProjection`` the KL
    divergence to the old distribution must additionally stay within the sum
    of both bounds plus a 10% margin.

    Args:
        ProjectionClass: Projection class under test.
        needs_cpp: Whether the case is skipped when ``cpp_projection`` cannot
            be imported.
        gaussian_params: Fixture providing ``"params"`` and ``"old_params"``.
    """
    if needs_cpp:
        try:
            import cpp_projection  # noqa: F401 -- availability probe only
        except ImportError:
            pytest.skip("cpp_projection not available")

    params = gaussian_params["params"]
    old_params = gaussian_params["old_params"]

    mean_bound, cov_bound = 0.1, 0.1
    proj = ProjectionClass(mean_bound=mean_bound, cov_bound=cov_bound,
                           contextual_std=False, full_cov=False)
    proj_params = proj.project(params, old_params)

    # Check basic properties
    assert proj_params["loc"].shape == params["loc"].shape
    assert proj_params["scale"].shape == params["scale"].shape
    assert jnp.all(jnp.isfinite(proj_params["loc"]))
    assert jnp.all(jnp.isfinite(proj_params["scale"]))
    assert jnp.all(proj_params["scale"] > 0)

    # Only check KL bounds for KL projection
    if ProjectionClass is KLProjection:
        kl = compute_gaussian_kl(proj_params, old_params)
        max_kl = (mean_bound + cov_bound) * 1.1  # Allow 10% margin

        # Written as "not all(kl <= max)" rather than "any(kl > max)" so that
        # a NaN KL value fails the test instead of slipping through.
        if not jnp.all(kl <= max_kl):
            print(f"\nProjection type: {ProjectionClass.__name__}")
            print(f"KL divergence: {kl}")
            print(f"Max allowed KL: {max_kl}")
            print_gaussian_params("Original params", params)
            print_gaussian_params("Old params", old_params)
            print_gaussian_params("Projected params", proj_params)
            # pytest.fail is not stripped under ``python -O`` like ``assert``.
            pytest.fail(f"KL divergence {kl} exceeds bound {max_kl}")
|
|
|
|
@pytest.mark.parametrize("ProjectionClass", [KLProjection, FrobeniusProjection])
def test_full_covariance_projection(ProjectionClass):
    """Test projections with full covariance.

    The projected ``scale_tril`` must keep its shape, be finite, be lower
    triangular, and produce a positive-definite covariance. For
    ``KLProjection`` the KL divergence to the old distribution must also stay
    within the sum of both bounds plus a 10% margin. The whole test is
    skipped when ``cpp_projection`` cannot be imported.
    """
    try:
        import cpp_projection
    except ImportError:
        pytest.skip("cpp_projection not available")

    # Current distribution: batch of two 2-D Gaussians.
    params = {
        "loc": jnp.array([[1.0, -1.0], [0.5, -0.5]]),
        "scale_tril": jnp.array([
            [[0.5, 0.0], [0.1, 0.4]],
            [[0.6, 0.0], [0.2, 0.3]]
        ]),
    }

    # Old distribution: zero mean, isotropic factor of 0.3.
    old_params = {
        "loc": jnp.zeros_like(params["loc"]),
        "scale_tril": jnp.array([
            [[0.3, 0.0], [0.0, 0.3]],
            [[0.3, 0.0], [0.0, 0.3]]
        ]),
    }

    mean_bound, cov_bound = 0.1, 0.1
    projector = ProjectionClass(mean_bound=mean_bound, cov_bound=cov_bound,
                                contextual_std=False, full_cov=True)
    proj_params = projector.project(params, old_params)

    # Basic properties: shapes first, then finiteness.
    for key in ("loc", "scale_tril"):
        assert proj_params[key].shape == params[key].shape
    for key in ("loc", "scale_tril"):
        assert jnp.all(jnp.isfinite(proj_params[key]))

    projected_tril = proj_params["scale_tril"]

    # Everything strictly above the diagonal must vanish.
    above_diagonal = jnp.triu(projected_tril, k=1)
    assert jnp.allclose(above_diagonal, 0.0, atol=1e-5)

    # Positive definiteness of L @ L^T via its eigenvalues.
    covariance = jnp.matmul(projected_tril, jnp.swapaxes(projected_tril, -1, -2))
    assert jnp.all(jnp.linalg.eigvalsh(covariance) > 0)

    # Only the KL projection is held to the KL bound.
    if ProjectionClass not in [KLProjection]:
        return

    kl = compute_gaussian_kl(proj_params, old_params, full_cov=True)
    max_kl = (mean_bound + cov_bound) * 1.1  # Allow 10% margin
    assert jnp.all(kl <= max_kl), f"KL divergence {kl} exceeds bound {max_kl}"
|
|
|
|
def test_contextual_vs_noncontextual():
    """Check that a non-contextual projection shares one scale across the batch.

    Only the ``contextual_std=False`` path of ``FrobeniusProjection`` is
    exercised: the input scales differ per batch element, and the projected
    scales must be identical along the batch axis (up to ``atol=1e-5``).
    """
    params = {
        "loc": jnp.array([[1.0, -1.0], [0.5, -0.5]]),
        "scale": jnp.array([[0.5, 0.6], [0.7, 0.8]]),
    }
    old_params = {
        "loc": jnp.zeros_like(params["loc"]),
        "scale": jnp.ones_like(params["scale"]) * 0.3,
    }

    # Test with contextual=False
    projector = FrobeniusProjection(
        mean_bound=0.1,
        cov_bound=0.1,
        contextual_std=False,
        full_cov=False,
    )
    projected = projector.project(params, old_params)

    # A zero standard deviation over axis 0 means every batch element
    # received the same scale.
    batch_spread = jnp.std(projected["scale"], axis=0)
    assert jnp.allclose(batch_spread, 0.0, atol=1e-5)
|
|
|
|
@pytest.mark.parametrize("ProjectionClass,needs_cpp", [
    (KLProjection, True),
    (WassersteinProjection, False),
    (FrobeniusProjection, False),
    (IdentityProjection, False)
])
def test_gradient_flow(ProjectionClass, needs_cpp):
    """Test that gradients can flow through all projections.

    Differentiates two scalar losses with respect to the input parameters: a
    sum of squares of the projected parameters, and the projection's own
    trust-region loss. Both gradients must have ``"loc"``/``"scale"`` entries
    of matching shape that contain only finite values.
    """
    if needs_cpp:
        try:
            import cpp_projection
        except ImportError:
            pytest.skip("cpp_projection not available")

    # Single-sample batch of a 2-D diagonal Gaussian.
    params = {
        "loc": jnp.array([[1.0, -1.0]]),
        "scale": jnp.array([[0.5, 0.5]]),
    }
    old_params = {
        "loc": jnp.zeros_like(params["loc"]),
        "scale": jnp.ones_like(params["scale"]) * 0.3,
    }

    proj = ProjectionClass(mean_bound=0.1, cov_bound=0.1)

    def proj_loss_fn(p):
        """Sum of squares of the projected loc and scale."""
        projected = proj.project(p, old_params)
        return jnp.sum(projected["loc"]**2 + projected["scale"]**2)

    def tr_loss_fn(p):
        """Trust-region loss between the raw and projected parameters."""
        projected = proj.project(p, old_params)
        return proj.get_trust_region_loss(p, projected)

    # Gradient through the projection first, then through the trust region loss.
    all_grads = [jax.grad(loss_fn)(params) for loss_fn in (proj_loss_fn, tr_loss_fn)]

    # Check gradients exist and have correct shape
    for grads in all_grads:
        assert "loc" in grads and "scale" in grads
        for key in ("loc", "scale"):
            assert grads[key].shape == params[key].shape
        for key in ("loc", "scale"):
            assert jnp.all(jnp.isfinite(grads[key]))