ITPAL JAX

JAX bindings and native implementations of differentiable trust region projections for Gaussian policies. The KL projection is handled by ITPAL's C++ implementation, while the Wasserstein and Frobenius projections are implemented natively in JAX. These projections provide exact solutions to the trust region constraint, unlike approximate methods such as PPO's clipped objective.

Features

  • Multiple projection types:
    • KL (Kullback-Leibler divergence)
    • Wasserstein (diagonal covariance only)
    • Frobenius (work in progress, not yet tested)
    • Identity (no projection)
  • Support for both diagonal and full covariance Gaussians (the full covariance is induced from its Cholesky decomposition)
  • Contextual and non-contextual standard deviations (non-contextual means all standard deviations in the batch are expected to be identical; see the sketch after this list)
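
As referenced above, a minimal sketch of the two standard deviation modes; the array values are made up for illustration, and only KLProjection and its constructor arguments are taken from the usage example below.

import jax.numpy as jnp
from itpal_jax import KLProjection

# Non-contextual std: every row of "scale" is identical across the batch.
noncontextual_params = {
    "loc": jnp.array([[0.1, -0.2], [0.3, 0.0], [-0.5, 0.4]]),
    "scale": jnp.ones((3, 2)) * 0.2,
}

# Contextual std: each batch element carries its own standard deviations.
contextual_params = {
    "loc": jnp.array([[0.1, -0.2], [0.3, 0.0], [-0.5, 0.4]]),
    "scale": jnp.array([[0.2, 0.3], [0.1, 0.5], [0.4, 0.2]]),
}

# With contextual_std=False the projector expects the non-contextual layout above.
proj_noncontextual = KLProjection(
    mean_bound=0.1, cov_bound=0.1, contextual_std=False, full_cov=False
)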

Installation

python3.10 -m venv .venv  # newer Python versions have issues with ITPAL
source .venv/bin/activate
pip install -e .
# install ITPAL itself, e.g. by copying its compiled .so file into the venv's site-packages
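
A quick sanity check can then be run from the activated venv; note that the import name itpal for the compiled C++ bindings is an assumption and may differ depending on how the .so file is named in your build.

# run inside the activated venv; adjust the first import to your build's module name
import itpal      # compiled ITPAL bindings (name assumed)
import itpal_jax  # the package installed above
print("ok")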

Usage

import jax.numpy as jnp
from itpal_jax import KLProjection

# Create projector
proj = KLProjection(
    mean_bound=0.1,        # KL bound for mean
    cov_bound=0.1,         # KL bound for covariance
    contextual_std=True,   # Whether to use contextual standard deviations
    full_cov=False         # Whether to use full covariance matrix
)

# Project Gaussian parameters
new_params = {
    "loc": jnp.array([[1.0, -1.0]]),      # mean
    "scale": jnp.array([[0.5, 0.5]])      # standard deviations
}
old_params = {
    "loc": jnp.zeros((1, 2)),
    "scale": jnp.ones((1, 2)) * 0.3
}

# Get projected parameters
proj_params = proj.project(new_params, old_params)

# Get trust region loss
loss = proj.get_trust_region_loss(new_params, proj_params)
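
Because the projections are differentiable, gradients can be taken straight through both the projection and the trust region loss. A minimal sketch follows; the summation is a defensive assumption in case the loss is returned per batch element.

import jax
import jax.numpy as jnp

# Differentiate through the projection and the trust region penalty
# with respect to the new (pre-projection) policy parameters.
def loss_fn(params):
    projected = proj.project(params, old_params)
    # summed defensively in case the loss is returned per batch element
    return jnp.sum(proj.get_trust_region_loss(params, projected))

grads = jax.grad(loss_fn)(new_params)  # pytree with "loc" and "scale" gradients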

Testing

pytest tests/test_projections.py

Note: The test suite verifies:

  1. All projections run without errors and preserve basic properties (shapes, positive definiteness)
  2. The KL bounds are (approximately) met for the true KL projection (both diagonal and full covariance); a manual check is sketched after this list
  3. Gradients can be computed through all projections:
    • Both through the projection operation and through the trust region loss
    • Gradients have the correct shapes and are finite
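
As a reference for the bound check in point 2, a minimal sketch of the diagonal-Gaussian KL divergence evaluated on the projected parameters from the usage example; the helper diag_gauss_kl is written here for illustration and is not part of the package.

import jax.numpy as jnp

def diag_gauss_kl(p, q):
    # KL(N(p) || N(q)) for diagonal Gaussians given as {"loc", "scale"} dicts
    var_p, var_q = p["scale"] ** 2, q["scale"] ** 2
    return 0.5 * jnp.sum(
        jnp.log(var_q / var_p) + (var_p + (p["loc"] - q["loc"]) ** 2) / var_q - 1.0,
        axis=-1,
    )

# After a KL projection, the divergence to the old policy should stay
# (approximately) within the configured mean and covariance bounds.
kl = diag_gauss_kl(proj_params, old_params)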