# ITPAL JAX

JAX bindings and native implementations of differentiable trust region projections for Gaussian policies. The KL projection is handled by [ITPAL](https://github.com/ALRhub/ITPAL)'s C++ implementation, while the Wasserstein and Frobenius projections are implemented in JAX. These projections enforce the trust region constraints exactly, whereas methods such as PPO only approximate them.

## Features

- Multiple projection types:
  - KL (Kullback-Leibler divergence)
  - Wasserstein (diagonal covariance only)
  - Frobenius (work in progress, not tested)
  - Identity (no projection)
- Support for both diagonal and full covariance Gaussians (induced from the Cholesky decomposition)
- Contextual and non-contextual standard deviations (non-contextual means all standard deviations in a batch are expected to be the same)

## Installation

```bash
python3.10 -m venv .venv  # newer Python versions have issues with ITPAL...
source .venv/bin/activate
pip install -e .
# install itpal (by e.g. copying the .so file into site packages for the venv)
```

## Usage

```python
import jax.numpy as jnp
from itpal_jax import KLProjection

# Create projector
proj = KLProjection(
    mean_bound=0.1,       # KL bound for mean
    cov_bound=0.1,        # KL bound for covariance
    contextual_std=True,  # Whether to use contextual standard deviations
    full_cov=False,       # Whether to use full covariance matrix
)

# Project Gaussian parameters
new_params = {
    "loc": jnp.array([[1.0, -1.0]]),   # mean
    "scale": jnp.array([[0.5, 0.5]]),  # standard deviations
}
old_params = {
    "loc": jnp.zeros((1, 2)),
    "scale": jnp.ones((1, 2)) * 0.3,
}

# Get projected parameters
proj_params = proj.project(new_params, old_params)

# Get trust region loss
loss = proj.get_trust_region_loss(new_params, proj_params)
```

## Testing

```bash
pytest tests/test_projections.py
```

*Note*: The test suite verifies:

1. All projections run without errors and maintain basic properties (shapes, positive definiteness)
2. KL bounds are (approximately) met for:
   - KL projection (both diagonal and full covariance)
   - Wasserstein projection (diagonal covariance only)
3. Gradients can be computed through all projections:
   - Both through the projection operation and the trust region loss
   - Gradients have correct shapes and are finite
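
To make the bound check concrete, here is a minimal sketch of how the KL divergence between two diagonal Gaussians can be computed by hand in JAX. It assumes that `mean_bound` and `cov_bound` apply separately to the mean part and the covariance part of KL(new || old), which this README does not spell out. The helper `diag_gaussian_kl_parts` is hypothetical and not part of `itpal_jax`.

```python
import jax.numpy as jnp


def diag_gaussian_kl_parts(new_params, old_params):
    """Return (mean_part, cov_part) of KL(new || old) for diagonal Gaussians.

    Hypothetical helper, not part of itpal_jax. Both arguments are dicts with
    "loc" and "scale" arrays of shape (batch, dim); "scale" holds standard
    deviations. The two returned parts sum to the full KL divergence.
    """
    mu_new, std_new = new_params["loc"], new_params["scale"]
    mu_old, std_old = old_params["loc"], old_params["scale"]
    var_new, var_old = std_new**2, std_old**2

    # Mean part: half the squared Mahalanobis distance under the old covariance.
    mean_part = 0.5 * jnp.sum((mu_new - mu_old) ** 2 / var_old, axis=-1)

    # Covariance part: trace term, dimension offset, and log-determinant ratio.
    dim = mu_new.shape[-1]
    trace_term = jnp.sum(var_new / var_old, axis=-1)
    log_det_ratio = jnp.sum(jnp.log(var_old) - jnp.log(var_new), axis=-1)
    cov_part = 0.5 * (trace_term - dim + log_det_ratio)
    return mean_part, cov_part


# Same example parameters as in the Usage section.
new_params = {"loc": jnp.array([[1.0, -1.0]]), "scale": jnp.array([[0.5, 0.5]])}
old_params = {"loc": jnp.zeros((1, 2)), "scale": jnp.ones((1, 2)) * 0.3}

mean_part, cov_part = diag_gaussian_kl_parts(new_params, old_params)
print(mean_part, cov_part)
```

For the unprojected example parameters, both parts (roughly 11.1 and 0.76) are well above a bound of 0.1, so a projection would be needed. Under the assumption above, calling the same helper on the projected parameters and `old_params` is one way to check that each part stays approximately within its bound.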