# ITPAL JAX

JAX bindings and native implementations of differentiable trust region projections for Gaussian policies. The KL projection is handled by [ITPAL](https://github.com/ALRhub/ITPAL)'s C++ implementation, while the Wasserstein and Frobenius projections are implemented in JAX. These projections solve the trust region constraints exactly, unlike approximate methods such as PPO.

## Features

- Multiple projection types:
  - KL (Kullback-Leibler divergence)
  - Wasserstein (diagonal covariance only)
  - Frobenius (WIP, problem with covariance projections)
  - Identity (no projection)
- Support for both diagonal and full covariance Gaussians (induced from the Cholesky decomposition)
- Contextual and non-contextual standard deviations (non-contextual means all standard deviations in a batch are expected to be the same)

## Installation

```bash
python3.10 -m venv .venv  # newer versions have issues with ITPAL...
source .venv/bin/activate
pip install -e .
# install itpal (by e.g. copying the .so file into site packages for the venv)
```

## Usage

```python
import jax.numpy as jnp
from itpal_jax import KLProjection

# Create projector
proj = KLProjection(
    mean_bound=0.1,       # KL bound for mean
    cov_bound=0.1,        # KL bound for covariance
    contextual_std=True,  # Whether to use contextual standard deviations
    full_cov=False        # Whether to use full covariance matrix
)

# Gaussian parameters to project (new) and the trust region center (old)
new_params = {
    "loc": jnp.array([[1.0, -1.0]]),   # mean
    "scale": jnp.array([[0.5, 0.5]])   # standard deviations
}
old_params = {
    "loc": jnp.zeros((1, 2)),
    "scale": jnp.ones((1, 2)) * 0.3
}

# Get projected parameters
proj_params = proj.project(new_params, old_params)

# Get trust region loss
loss = proj.get_trust_region_loss(new_params, proj_params)
```

## Testing

```bash
pytest tests/test_projections.py
```

*Note*: The test suite verifies:

1. All projections run without errors and maintain basic properties (shapes, positive definiteness)
2. KL bounds are (approximately) met for the true KL projection (both diagonal and full covariance)
3. Gradients can be computed through all projections:
   - Both through the projection operation and the trust region loss
   - Gradients have correct shapes and are finite
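
The KL-bound check mentioned in the test notes can be reproduced by hand for the diagonal case. The sketch below is not taken from the test suite. It assumes the `loc`/`scale` parameter dictionaries from the usage example, and that `mean_bound` and `cov_bound` constrain the mean and covariance parts of KL(new || old) respectively. The helper name `diag_gaussian_kl_parts` is hypothetical and not part of `itpal_jax`.

```python
import jax.numpy as jnp


def diag_gaussian_kl_parts(new_params, old_params):
    """Split KL(new || old) for diagonal Gaussians into a mean and a covariance part.

    Hypothetical helper for illustration; expects dicts with "loc" and "scale"
    arrays of shape (batch, dim), where "scale" holds standard deviations.
    """
    mu_new, std_new = new_params["loc"], new_params["scale"]
    mu_old, std_old = old_params["loc"], old_params["scale"]
    var_new, var_old = std_new**2, std_old**2

    # Mean part: half the squared Mahalanobis distance under the old covariance.
    mean_part = 0.5 * jnp.sum((mu_new - mu_old) ** 2 / var_old, axis=-1)

    # Covariance part: 0.5 * (tr(S_old^-1 S_new) - k + log det S_old - log det S_new).
    k = mu_new.shape[-1]
    cov_part = 0.5 * (
        jnp.sum(var_new / var_old, axis=-1)
        - k
        + 2.0 * jnp.sum(jnp.log(std_old) - jnp.log(std_new), axis=-1)
    )
    return mean_part, cov_part


# Unprojected parameters from the usage example: both parts are far above 0.1.
new_params = {"loc": jnp.array([[1.0, -1.0]]), "scale": jnp.array([[0.5, 0.5]])}
old_params = {"loc": jnp.zeros((1, 2)), "scale": jnp.ones((1, 2)) * 0.3}
mean_kl, cov_kl = diag_gaussian_kl_parts(new_params, old_params)
print(mean_kl, cov_kl)
```

Applied to `proj_params` and `old_params` instead, both parts should come out at or slightly below the configured bounds, provided the assumed decomposition matches the one the projection uses.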