<h1 align="center">
  <br>
  <img src='./itpal_jax.svg' width="250px">
  <br><br>
  <b>ITPAL JAX</b>
  <br><br>
</h1>

JAX bindings and native implementations of differentiable trust region projections for Gaussian policies. The KL projection is handled by [ITPAL](https://github.com/ALRhub/ITPAL)'s C++ implementation, while the Wasserstein and Frobenius projections are implemented directly in JAX. These projections solve the trust region constraints exactly, in contrast to approximate methods such as PPO.
## Features
- Multiple projection types:
  - KL (Kullback-Leibler divergence)
  - Wasserstein (diagonal covariance only)
  - Frobenius (work in progress; the covariance projection currently has problems)
  - Identity (no projection)
- Support for both diagonal and full covariance Gaussians (the covariance is induced from its Cholesky decomposition)
- Contextual and non-contextual standard deviations (non-contextual means that all standard deviations in a batch are expected to be identical)
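
To make the last two points concrete, here is a minimal sketch in plain JAX. It does not call into `itpal_jax`; the `(batch, dim)` layout is assumed from the usage example in this README, and the helper `is_non_contextual` and all variable names are hypothetical illustrations rather than part of the library.

```python
import jax.numpy as jnp

batch, dim = 4, 2

# Contextual: each batch element carries its own standard deviations.
contextual_scale = jnp.linspace(0.1, 0.8, batch * dim).reshape(batch, dim)

# Non-contextual: one shared vector of standard deviations, repeated per row.
shared_scale = jnp.array([0.3, 0.5])
non_contextual_scale = jnp.broadcast_to(shared_scale, (batch, dim))


def is_non_contextual(scale):
    """Return True if every row of `scale` equals the first row (hypothetical helper)."""
    return bool(jnp.allclose(scale, scale[:1]))


# Full covariance induced from a lower-triangular Cholesky factor L: cov = L @ L^T.
chol = jnp.array([[0.5, 0.0],
                  [0.1, 0.4]])
full_cov = chol @ chol.T

# A diagonal covariance is the special case L = diag(scale).
diag_cov = jnp.diag(shared_scale ** 2)
```

Parameterizing the full covariance through a Cholesky factor with a positive diagonal keeps it symmetric and positive definite by construction.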
## Installation
```bash
python3.10 -m venv .venv # newer versions have issues with ITPAL...
source .venv/bin/activate
pip install -e .
# install ITPAL (e.g. by copying its .so file into the venv's site-packages)
```
## Usage
```python
import jax.numpy as jnp
from itpal_jax import KLProjection

# Create projector
proj = KLProjection(
    mean_bound=0.1,       # KL bound for mean
    cov_bound=0.1,        # KL bound for covariance
    contextual_std=True,  # Whether to use contextual standard deviations
    full_cov=False,       # Whether to use full covariance matrix
)

# Project Gaussian parameters
new_params = {
    "loc": jnp.array([[1.0, -1.0]]),   # mean
    "scale": jnp.array([[0.5, 0.5]]),  # standard deviations
}
old_params = {
    "loc": jnp.zeros((1, 2)),
    "scale": jnp.ones((1, 2)) * 0.3,
}

# Get projected parameters
proj_params = proj.project(new_params, old_params)

# Get trust region loss
loss = proj.get_trust_region_loss(new_params, proj_params)
```
## Testing
```bash
pytest tests/test_projections.py
```
*Note*: The test suite verifies that:
1. All projections run without errors and maintain basic properties (shapes, positive definiteness)
2. The KL bounds are (approximately) met for the true KL projection (both diagonal and full covariance)
3. Gradients can be computed through all projections:
   - Both through the projection operation and the trust region loss
   - Gradients have correct shapes and are finite
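
The KL-bound check in point 2 can be approximated by hand for the diagonal case. The sketch below is not taken from the test suite. It splits the closed-form KL divergence between two diagonal Gaussians into a mean part and a covariance part. It assumes that `mean_bound` and `cov_bound` constrain these two parts separately and that the divergence is taken from the new distribution to the old one. The function `diag_gaussian_kl_parts` is a hypothetical helper, not part of `itpal_jax`.

```python
import jax.numpy as jnp


def diag_gaussian_kl_parts(loc, scale, old_loc, old_scale):
    """Mean and covariance parts of KL(N(loc, scale^2) || N(old_loc, old_scale^2)).

    The two parts sum to the full KL divergence between the diagonal Gaussians.
    """
    var, old_var = scale ** 2, old_scale ** 2
    # Mahalanobis distance of the means under the old covariance.
    mean_part = 0.5 * jnp.sum((old_loc - loc) ** 2 / old_var, axis=-1)
    # Trace term, dimension offset, and log-determinant ratio.
    cov_part = 0.5 * jnp.sum(var / old_var - 1.0 + jnp.log(old_var / var), axis=-1)
    return mean_part, cov_part


# Unprojected parameters from the usage example.
loc = jnp.array([[1.0, -1.0]])
scale = jnp.array([[0.5, 0.5]])
old_loc = jnp.zeros((1, 2))
old_scale = jnp.ones((1, 2)) * 0.3

mean_kl, cov_kl = diag_gaussian_kl_parts(loc, scale, old_loc, old_scale)
# Both parts are well above a bound of 0.1 here, so a projection would be needed.
```

Under the assumptions above, calling the same helper on the projected `loc` and `scale` should return values at or below the configured bounds, up to numerical tolerance.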