2024-12-21 17:50:26 +01:00





ITPAL JAX

JAX bindings and native implementations of differentiable trust region projections for Gaussian policies. The KL projection is handled by ITPAL's C++ implementation, while the Wasserstein and Frobenius projections are implemented natively in JAX. These projections solve the trust region constraints exactly, unlike approximate approaches such as PPO's clipped objective.
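Because the projector takes separate `mean_bound` and `cov_bound` values, it helps to see how a KL divergence between Gaussians splits into a mean part and a covariance part. The NumPy sketch below computes both parts for diagonal Gaussians, assuming the divergence is KL(new || old); the direction and any scaling used inside ITPAL may differ, and `diag_gauss_kl_parts` is a hypothetical helper, not part of itpal_jax.

```python
import numpy as np


def diag_gauss_kl_parts(loc, scale, old_loc, old_scale):
    """Split KL(N(loc, diag(scale^2)) || N(old_loc, diag(old_scale^2))) into
    a mean part and a covariance part, one value per batch element.

    Arrays have shape (batch, dim); `scale` holds standard deviations.
    """
    var, old_var = scale**2, old_scale**2
    # Mean part: half the Mahalanobis distance under the old covariance.
    mean_part = 0.5 * np.sum((loc - old_loc) ** 2 / old_var, axis=-1)
    # Covariance part: trace term, dimension offset and log-determinant ratio.
    dim = loc.shape[-1]
    cov_part = 0.5 * (
        np.sum(var / old_var, axis=-1)
        - dim
        + np.sum(np.log(old_var) - np.log(var), axis=-1)
    )
    return mean_part, cov_part


# Values from the usage example: both parts are far above a bound of 0.1,
# so a projection would have to change the new parameters.
mean_part, cov_part = diag_gauss_kl_parts(
    np.array([[1.0, -1.0]]), np.array([[0.5, 0.5]]),
    np.zeros((1, 2)), np.ones((1, 2)) * 0.3,
)
print(mean_part, cov_part)  # about 11.11 and 0.756
```

The two parts add up to the full KL divergence, which is why each can be given its own bound.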

Features

  • Multiple projection types:
    • KL (Kullback-Leibler divergence)
    • Wasserstein (only diagonal covariance)
    • Frobenius (work in progress, untested)
    • Identity (no projection)
  • Support for both diagonal and full covariance Gaussians (full covariances are parameterized by their Cholesky factor)
  • Contextual and non-contextual standard deviations (non-contextual means the standard deviations are expected to be identical for every element in the batch)
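A minimal NumPy sketch of what the last two features mean, assuming the full covariance is built from a lower-triangular factor L as Σ = L Lᵀ, and that "non-contextual" can be checked by comparing every batch row with the first one. `cov_from_cholesky` and `is_non_contextual` are hypothetical helpers for illustration, not itpal_jax functions.

```python
import numpy as np


def cov_from_cholesky(chol):
    """Build covariance matrices Sigma = L @ L^T from lower-triangular
    Cholesky factors L of shape (batch, dim, dim)."""
    return chol @ np.swapaxes(chol, -1, -2)


def is_non_contextual(scale):
    """True if every batch element carries the same standard deviations,
    i.e. the std does not depend on the context (state)."""
    return bool(np.allclose(scale, scale[:1]))


chol = np.array([[[0.5, 0.0], [0.1, 0.4]]])  # one full-covariance element
cov = cov_from_cholesky(chol)
# L has a positive diagonal, so Sigma is symmetric positive definite.
assert np.all(np.linalg.eigvalsh(cov) > 0)
print(cov)

print(is_non_contextual(np.ones((4, 2)) * 0.3))                # same std in every row
print(is_non_contextual(np.array([[0.3, 0.3], [0.5, 0.5]])))   # std varies per row
```

Parameterizing through L keeps Σ positive definite by construction as long as the diagonal of L stays positive, which is one reason the Cholesky factor is a common choice.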

Installation

python3.10 -m venv .venv  # use Python 3.10; newer versions have compatibility issues with ITPAL
source .venv/bin/activate
pip install -e .
# install ITPAL, e.g. by copying its compiled .so file into the venv's site-packages directory
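To find the directory the ITPAL `.so` file should be copied into, a small standard-library script can print the active interpreter's site-packages path. This is a convenience sketch, not part of the package; it assumes it is run with the venv's `python` after activation.

```python
import sys
import sysconfig

# Run with the venv's interpreter, i.e. after `source .venv/bin/activate`.
in_venv = sys.prefix != sys.base_prefix
# Compiled extension modules (.so files) belong in the "platlib" directory,
# which is the venv's site-packages directory when a venv is active.
site_packages = sysconfig.get_paths()["platlib"]

print(f"Python {sys.version_info.major}.{sys.version_info.minor}, venv active: {in_venv}")
print(f"Copy the ITPAL .so file into: {site_packages}")
if sys.version_info[:2] != (3, 10):
    print("Warning: this README recommends Python 3.10 for ITPAL.")
```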

Usage

import jax.numpy as jnp
from itpal_jax import KLProjection

# Create projector
proj = KLProjection(
    mean_bound=0.1,        # KL bound for mean
    cov_bound=0.1,         # KL bound for covariance
    contextual_std=True,   # Whether to use contextual standard deviations
    full_cov=False         # Whether to use full covariance matrix
)

# New (updated) and old Gaussian parameters: batch size 1, dimension 2
new_params = {
    "loc": jnp.array([[1.0, -1.0]]),      # mean
    "scale": jnp.array([[0.5, 0.5]])      # standard deviations
}
old_params = {
    "loc": jnp.zeros((1, 2)),
    "scale": jnp.ones((1, 2)) * 0.3
}

# Get projected parameters
proj_params = proj.project(new_params, old_params)

# Get trust region loss
loss = proj.get_trust_region_loss(new_params, proj_params)
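In the example above, `proj.project` acts as a black box. For intuition about the mean step only, here is a NumPy sketch of a closed-form projection, assuming the mean part of the trust region is the Mahalanobis distance under the old covariance (without a 0.5 factor). It is not the itpal_jax code, `project_mean` is a hypothetical name, and the covariance part of the KL projection (handled by ITPAL's C++ backend) is not covered.

```python
import numpy as np


def project_mean(loc, old_loc, old_scale, mean_bound):
    """Pull `loc` toward `old_loc` so that the assumed mean distance
    d = sum(((loc - old_loc) / old_scale) ** 2) is at most `mean_bound`.

    Shapes: (batch, dim). Batch elements already inside the bound are unchanged.
    """
    diff = loc - old_loc
    dist = np.sum((diff / old_scale) ** 2, axis=-1, keepdims=True)
    # Scaling diff by sqrt(bound / d) scales d by bound / d, i.e. to exactly
    # the bound; clipping at 1 leaves means inside the trust region untouched.
    factor = np.minimum(1.0, np.sqrt(mean_bound / np.maximum(dist, 1e-12)))
    return old_loc + factor * diff


projected = project_mean(
    np.array([[1.0, -1.0]]), np.zeros((1, 2)), np.ones((1, 2)) * 0.3, mean_bound=0.1
)
print(projected)  # roughly [[0.067, -0.067]]
```

Because the distance is quadratic in `diff`, rescaling along the straight line from the old mean lands exactly on the boundary, which is why no iterative solver is needed for this part.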

Testing

pytest tests/test_projections.py

The test suite verifies that:

  1. All projections run without errors and preserve basic properties (shapes, positive definiteness).
  2. The KL bounds are (approximately) satisfied for:
    • the KL projection (diagonal and full covariance)
    • the Wasserstein projection (diagonal covariance only)
  3. Gradients can be computed through all projections:
    • through both the projection operation and the trust region loss
    • the gradients have the correct shapes and are finite
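The kind of bound check described in item 2 can be illustrated without ITPAL. The sketch below assumes a diagonal Wasserstein-style covariance distance d = sum((σ/σ_old - 1)²), which may differ from the library's exact scaling, and projects by interpolating the standard deviations toward the old ones. `project_diag_std` is hypothetical, and the pytest-style test mirrors the property the README describes rather than the actual contents of tests/test_projections.py.

```python
import numpy as np


def project_diag_std(scale, old_scale, cov_bound):
    """Interpolate standard deviations toward `old_scale` so that the assumed
    covariance distance d = sum((scale / old_scale - 1) ** 2) is at most
    `cov_bound`. Shapes: (batch, dim)."""
    rel = scale / old_scale - 1.0
    dist = np.sum(rel**2, axis=-1, keepdims=True)
    factor = np.minimum(1.0, np.sqrt(cov_bound / np.maximum(dist, 1e-12)))
    # rel > -1 and factor <= 1, so the projected stds stay strictly positive.
    return old_scale * (1.0 + factor * rel)


def test_diag_std_projection_meets_bound():
    rng = np.random.default_rng(0)
    old_scale = rng.uniform(0.1, 1.0, size=(16, 3))
    scale = rng.uniform(0.1, 2.0, size=(16, 3))
    projected = project_diag_std(scale, old_scale, cov_bound=0.1)
    dist = np.sum((projected / old_scale - 1.0) ** 2, axis=-1)
    # The bound holds up to numerical tolerance and the stds stay positive.
    assert np.all(dist <= 0.1 + 1e-6)
    assert np.all(projected > 0)
```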