Actual README
parent e0eb46e14c
commit 6d10292fc4

README.md
@@ -1,2 +1,72 @@
# ITPAL JAX
Its bindings into ITPAL, written in/for jax. Thats it. End of README.

<h1 align="center">
  <img src='./itpal_jax.svg' width="250px">
  <br>
  <b>ITPAL JAX</b>
  <br>
</h1>

JAX bindings and native implementations of differentiable trust region projections for Gaussian policies. The KL projection is handled by [ITPAL](https://github.com/ALRhub/ITPAL)'s C++ implementation, while Wasserstein and Frobenius projections are implemented in JAX. These projections provide exact solutions for trust region constraints, unlike approximate methods like PPO.

## Features
- Multiple projection types:
  - KL (Kullback-Leibler divergence)
  - Wasserstein (only diagonal covariance)
  - Frobenius (work in progress, not tested)
  - Identity (no projection)
- Support for both diagonal and full covariance Gaussians (induced from a Cholesky decomposition)
- Contextual and non-contextual standard deviations (non-contextual means all standard deviations in a batch are expected to be the same)

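The covariance and standard deviation options above can be illustrated with plain `jax.numpy` arrays. This is a minimal sketch, not the library's own code: the `(batch, dim)` layout, the variable names, and the lower-triangular factor `scale_tril` are assumptions made for illustration, since the README does not spell out the exact input format for full covariances.

```python
import jax.numpy as jnp

batch, dim = 3, 2

# Contextual: every batch element carries its own standard deviations.
contextual_scale = jnp.array([[0.5, 0.4], [0.3, 0.2], [0.6, 0.1]])

# Non-contextual: one shared vector, repeated for every batch element.
shared_scale = jnp.array([0.5, 0.4])
noncontextual_scale = jnp.broadcast_to(shared_scale, (batch, dim))

# Diagonal covariance: squared standard deviations on the diagonal.
diag_cov = jnp.square(contextual_scale)[..., None] * jnp.eye(dim)

# Full covariance induced from a lower-triangular Cholesky factor L: cov = L @ L.T.
scale_tril = jnp.array([[0.5, 0.0], [0.1, 0.4]])  # hypothetical example factor
full_cov = scale_tril @ scale_tril.T
```

Building the covariance from a Cholesky factor with a positive diagonal keeps it symmetric and positive definite by construction.
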
## Installation

```bash
python3.10 -m venv .venv # newer Python versions have issues with ITPAL...
source .venv/bin/activate
pip install -e .
# install itpal (e.g. by copying the .so file into the venv's site-packages)
```

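If it is unclear where the venv's site-packages directory lives, the standard library can report it. The snippet below is a small sketch using only `sysconfig`; run it with the venv's interpreter. It only prints a path and does not install or verify ITPAL.

```python
import sysconfig

# "platlib" is the directory for platform-specific packages,
# which is where compiled extension modules such as .so files belong.
site_packages = sysconfig.get_paths()["platlib"]
print(site_packages)
```
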
## Usage

```python
import jax.numpy as jnp
from itpal_jax import KLProjection

# Create projector
proj = KLProjection(
    mean_bound=0.1,        # KL bound for mean
    cov_bound=0.1,         # KL bound for covariance
    contextual_std=True,   # Whether to use contextual standard deviations
    full_cov=False         # Whether to use full covariance matrix
)

# Gaussian parameters of the new policy (to be projected)
new_params = {
    "loc": jnp.array([[1.0, -1.0]]),      # mean
    "scale": jnp.array([[0.5, 0.5]])      # standard deviations
}
# Gaussian parameters of the old policy (center of the trust region)
old_params = {
    "loc": jnp.zeros((1, 2)),
    "scale": jnp.ones((1, 2)) * 0.3
}

# Get projected parameters
proj_params = proj.project(new_params, old_params)

# Get trust region loss
loss = proj.get_trust_region_loss(new_params, proj_params)
```

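To get a feel for what `mean_bound` and `cov_bound` constrain, the KL divergence between two diagonal Gaussians can be split into a mean part and a covariance part. The sketch below computes both in plain JAX for the parameters from the usage example. The helper `diag_gaussian_kl_parts` is hypothetical and not part of `itpal_jax`, and it is an assumption that the library applies its two bounds to exactly these two quantities.

```python
import jax.numpy as jnp


def diag_gaussian_kl_parts(new_params, old_params):
    """Mean and covariance parts of KL(new || old) for diagonal Gaussians.

    Illustrative helper only; not part of the itpal_jax API.
    """
    new_var = jnp.square(new_params["scale"])
    old_var = jnp.square(old_params["scale"])
    diff = new_params["loc"] - old_params["loc"]

    # Mean part: half the squared Mahalanobis distance under the old covariance.
    mean_part = 0.5 * jnp.sum(jnp.square(diff) / old_var, axis=-1)
    # Covariance part: trace term, dimension offset, and log-determinant ratio.
    cov_part = 0.5 * jnp.sum(
        new_var / old_var - 1.0 + jnp.log(old_var / new_var), axis=-1
    )
    return mean_part, cov_part


new_params = {"loc": jnp.array([[1.0, -1.0]]), "scale": jnp.array([[0.5, 0.5]])}
old_params = {"loc": jnp.zeros((1, 2)), "scale": jnp.ones((1, 2)) * 0.3}

mean_part, cov_part = diag_gaussian_kl_parts(new_params, old_params)
total_kl = mean_part + cov_part  # the two parts sum to the full KL divergence
```

For these values both parts are well above 0.1, so under the assumption stated above a projection with the bounds from the usage example would have to change both the mean and the standard deviations.
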
## Testing

```bash
pytest tests/test_projections.py
```

*Note*: The test suite verifies:

1. All projections run without errors and maintain basic properties (shapes, positive definiteness)
2. KL bounds are actually (approximately) met for:
   - KL projection (both diagonal and full covariance)
   - Wasserstein projection (diagonal covariance only)
3. Gradients can be computed through all projections:
   - Through both the projection operation and the trust region loss
   - Gradients have correct shapes and are finite
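
The gradient checks in item 3 can be sketched without ITPAL installed by using a stand-in projection. The function `toy_mean_projection` below is hypothetical: it simply rescales the mean step until a squared Mahalanobis distance meets the bound, and it is not the projection implemented in `itpal_jax` or used in `tests/test_projections.py`. It only illustrates the kind of shape and finiteness assertions described above.

```python
import jax
import jax.numpy as jnp


def toy_mean_projection(loc, old_loc, old_scale, bound):
    """Stand-in projection (NOT the itpal_jax implementation)."""
    diff = loc - old_loc
    # Squared Mahalanobis distance under the old diagonal covariance, per batch row.
    dist = jnp.sum(jnp.square(diff / old_scale), axis=-1, keepdims=True)
    # Shrink the step only where the bound is violated; the clamp avoids division by zero.
    factor = jnp.where(dist > bound, jnp.sqrt(bound / jnp.maximum(dist, 1e-12)), 1.0)
    return old_loc + factor * diff


loc = jnp.array([[1.0, -1.0]])
old_loc = jnp.zeros((1, 2))
old_scale = jnp.ones((1, 2)) * 0.3
bound = 0.1


def loss_fn(loc):
    projected = toy_mean_projection(loc, old_loc, old_scale, bound)
    return jnp.sum(jnp.square(projected))


projected = toy_mean_projection(loc, old_loc, old_scale, bound)
projected_dist = jnp.sum(jnp.square((projected - old_loc) / old_scale), axis=-1)
grads = jax.grad(loss_fn)(loc)

# The same style of checks the note describes: bound met, shapes match, values finite.
assert jnp.allclose(projected_dist, bound, rtol=1e-4)
assert grads.shape == loc.shape
assert bool(jnp.all(jnp.isfinite(grads)))
```
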