Actual README
Commit 6d10292fc4 (parent e0eb46e14c)
README.md | 74
@@ -1,2 +1,72 @@
# ITPAL JAX
It's bindings into ITPAL, written in/for JAX. That's it. End of README.
<h1 align="center">
  <img src='./itpal_jax.svg' width="250px">
  <br>
  <b>ITPAL JAX</b>
  <br>
</h1>

JAX bindings and native implementations of differentiable trust region projections for Gaussian policies. The KL projection is handled by [ITPAL](https://github.com/ALRhub/ITPAL)'s C++ implementation, while the Wasserstein and Frobenius projections are implemented in JAX. These projections provide exact solutions to the trust region constraints, unlike approximate methods such as PPO.

## Features
- Multiple projection types:
  - KL (Kullback-Leibler divergence)
  - Wasserstein (only diagonal covariance)
  - Frobenius (work in progress, not tested)
  - Identity (no projection)
- Support for both diagonal and full covariance Gaussians (full covariances are induced from their Cholesky decomposition)
- Contextual and non-contextual standard deviations (non-contextual means that all standard deviations in a batch are expected to be the same)
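
The covariance-related options above can be illustrated with plain `jax.numpy`. This is a minimal sketch and not code from `itpal_jax`: the names `scale_diag`, `scale_tril`, and `shared_std` are chosen for this example only, and the exact shapes and keys the library expects for full covariances are not stated in this README.

```python
import jax.numpy as jnp

batch, dim = 4, 2

# Diagonal Gaussian: one standard deviation per action dimension.
scale_diag = jnp.array([0.5, 0.3])
cov_diag = jnp.diag(scale_diag ** 2)

# Full-covariance Gaussian: a lower-triangular Cholesky factor L
# induces the covariance Sigma = L @ L.T (hypothetical example values).
scale_tril = jnp.array([[0.5, 0.0],
                        [0.1, 0.3]])
cov_full = scale_tril @ scale_tril.T

# The factor can be recovered from the covariance (up to float error).
recovered = jnp.linalg.cholesky(cov_full)

# Non-contextual: one std vector shared by every sample in the batch.
shared_std = jnp.array([0.5, 0.3])
non_contextual = jnp.broadcast_to(shared_std, (batch, dim))

# Contextual: each sample in the batch may have its own std vector.
contextual = shared_std * jnp.linspace(1.0, 2.0, batch)[:, None]
```

Here `non_contextual` has identical rows, while the rows of `contextual` differ per sample.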

## Installation

```bash
python3.10 -m venv .venv  # newer Python versions have issues with ITPAL...
source .venv/bin/activate
pip install -e .
# Install ITPAL (e.g. by copying its .so file into the venv's site-packages)
```
## Usage

```python
import jax.numpy as jnp
from itpal_jax import KLProjection

# Create projector
proj = KLProjection(
    mean_bound=0.1,       # KL bound for mean
    cov_bound=0.1,        # KL bound for covariance
    contextual_std=True,  # Whether to use contextual standard deviations
    full_cov=False        # Whether to use full covariance matrix
)

# Project Gaussian parameters
new_params = {
    "loc": jnp.array([[1.0, -1.0]]),  # mean
    "scale": jnp.array([[0.5, 0.5]])  # standard deviations
}
old_params = {
    "loc": jnp.zeros((1, 2)),
    "scale": jnp.ones((1, 2)) * 0.3
}

# Get projected parameters
proj_params = proj.project(new_params, old_params)

# Get trust region loss
loss = proj.get_trust_region_loss(new_params, proj_params)
```
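
To get a feeling for what the two bounds refer to, the KL divergence between two diagonal Gaussians can be split into a part that depends on the means and a part that depends on the covariances. The sketch below assumes that `mean_bound` and `cov_bound` constrain these two parts separately and that the divergence is taken as KL(new || old); neither is spelled out in this README. `diag_gaussian_kl_parts` is a hypothetical helper, not part of `itpal_jax`.

```python
import jax.numpy as jnp


def diag_gaussian_kl_parts(new, old):
    """Split KL(new || old) of diagonal Gaussians into mean and covariance parts."""
    var_new = new["scale"] ** 2
    var_old = old["scale"] ** 2
    dim = new["loc"].shape[-1]
    # Mean part: half the squared Mahalanobis distance under the old covariance.
    mean_part = 0.5 * jnp.sum((new["loc"] - old["loc"]) ** 2 / var_old, axis=-1)
    # Covariance part: trace term, dimension, and log-determinant ratio.
    cov_part = 0.5 * (
        jnp.sum(var_new / var_old, axis=-1)
        - dim
        + jnp.sum(jnp.log(var_old) - jnp.log(var_new), axis=-1)
    )
    return mean_part, cov_part


# Same values as in the usage example.
new_params = {"loc": jnp.array([[1.0, -1.0]]), "scale": jnp.array([[0.5, 0.5]])}
old_params = {"loc": jnp.zeros((1, 2)), "scale": jnp.ones((1, 2)) * 0.3}

mean_kl, cov_kl = diag_gaussian_kl_parts(new_params, old_params)
```

For these inputs both parts are larger than 0.1, so under the stated assumptions a projection with `mean_bound=0.1` and `cov_bound=0.1` would have to change both the mean and the standard deviations.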

## Testing

```bash
pytest tests/test_projections.py
```

*Note*: The test suite verifies:

1. All projections run without errors and maintain basic properties (shapes, positive definiteness)
2. KL bounds are (approximately) met for:
   - KL projection (both diagonal and full covariance)
   - Wasserstein projection (diagonal covariance only)
3. Gradients can be computed through all projections:
   - Both through the projection operation and the trust region loss
   - Gradients have correct shapes and are finite
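
The kinds of checks listed above could look roughly like the following. This is only a sketch and does not reproduce `tests/test_projections.py`: to stay self-contained it uses `toy_mean_projection`, a hypothetical stand-in that shrinks the new mean towards the old one until the mean part of the KL meets the bound, instead of calling `itpal_jax`.

```python
import jax
import jax.numpy as jnp


def toy_mean_projection(new_loc, old_loc, old_scale, mean_bound):
    """Stand-in projection: scale the mean step so the mean KL part <= mean_bound."""
    diff = new_loc - old_loc
    mean_part = 0.5 * jnp.sum(diff ** 2 / old_scale ** 2, axis=-1, keepdims=True)
    # step == 1 when the bound already holds, < 1 otherwise.
    step = jnp.sqrt(mean_bound / jnp.maximum(mean_part, mean_bound))
    return old_loc + step * diff


def check_projection(new_loc, old_loc, old_scale, mean_bound=0.1):
    projected = toy_mean_projection(new_loc, old_loc, old_scale, mean_bound)

    # 1. Basic properties: the output shape matches the input shape.
    assert projected.shape == new_loc.shape

    # 2. The bound is (approximately) met after projection.
    mean_part = 0.5 * jnp.sum((projected - old_loc) ** 2 / old_scale ** 2, axis=-1)
    assert bool(jnp.all(mean_part <= mean_bound + 1e-5))

    # 3. Gradients flow through the projection, with correct shape and finite values.
    def objective(loc):
        return jnp.sum(toy_mean_projection(loc, old_loc, old_scale, mean_bound)[..., 0])

    grads = jax.grad(objective)(new_loc)
    assert grads.shape == new_loc.shape
    assert bool(jnp.all(jnp.isfinite(grads)))
    return projected, grads


projected, grads = check_projection(
    new_loc=jnp.array([[1.0, -1.0]]),
    old_loc=jnp.zeros((1, 2)),
    old_scale=jnp.ones((1, 2)) * 0.3,
)
```

With a real projector, `toy_mean_projection` would be replaced by a call such as `proj.project(new_params, old_params)` from the usage example.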