diff --git a/README.md b/README.md index 44279b3..6c264ea 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,72 @@ -# ITPAL JAX -Its bindings into ITPAL, written in/for jax. Thats it. End of README. +

+ +
+ ITPAL JAX +
+

+ +JAX bindings and native implementations of differentiable trust region projections for Gaussian policies. The KL projection is handled by [ITPAL](https://github.com/ALRhub/ITPAL)'s C++ implementation, while Wasserstein and Frobenius projections are implemented in JAX. These projections provide exact solutions for trust region constraints, unlike approximate methods like PPO. + +## Features +- Multiple projection types: + - KL (Kullback-Leibler divergence) + - Wasserstein (only diagonal covariance) + - Frobenius (wip, not tested) + - Identity (no projection) +- Support for both diagonal and full covariance Gaussians (induced from cholesky decomposition) +- Contextual and non-contextual standard deviations (non-contextual means all standard deviations in batch are expected to be the same) + +## Installation + +```bash +python3.10 -m venv .venv # newer versions have issues with ITPAL... +source .venv/bin/activate +pip install -e . +# install itpal (by e.g. copying the .so file into site packages for the venv) +``` +## Usage + +```python +import jax.numpy as jnp +from itpal_jax import KLProjection + +# Create projector +proj = KLProjection( + mean_bound=0.1, # KL bound for mean + cov_bound=0.1, # KL bound for covariance + contextual_std=True, # Whether to use contextual standard deviations + full_cov=False # Whether to use full covariance matrix +) + +# Project Gaussian parameters +new_params = { + "loc": jnp.array([[1.0, -1.0]]), # mean + "scale": jnp.array([[0.5, 0.5]]) # standard deviations +} +old_params = { + "loc": jnp.zeros((1, 2)), + "scale": jnp.ones((1, 2)) * 0.3 +} + +# Get projected parameters +proj_params = proj.project(new_params, old_params) + +# Get trust region loss +loss = proj.get_trust_region_loss(new_params, proj_params) +``` + +## Testing + +```bash +pytest tests/test_projections.py +``` + +*Note*: The test suite verifies: + +1. All projections run without errors and maintain basic properties (shapes, positive definiteness) +2. KL bounds are actually (approximately) met for: + - KL projection (both diagonal and full covariance) + - Wasserstein projection (diagonal covariance only) +3. Gradients can be computed through all projections: + - Both through projection operation and trust region loss + - Gradients have correct shapes and are finite \ No newline at end of file