Compare commits


3 Commits

SHA1        Message                              Date
de2b9a10d6  Updated README                       2024-12-21 18:31:26 +01:00
9fb0014a99  Updated tests (no check kl for w2)   2024-12-21 18:31:07 +01:00
8e991ae05b  Fixes                                2024-12-21 18:31:01 +01:00
5 changed files with 130 additions and 79 deletions

View File

@@ -65,9 +65,7 @@ pytest tests/test_projections.py
 *Note*: The test suite verifies:
 1. All projections run without errors and maintain basic properties (shapes, positive definiteness)
-2. KL bounds are actually (approximately) met for:
-   - KL projection (both diagonal and full covariance)
-   - Wasserstein projection (diagonal covariance only)
+2. KL bounds are actually (approximately) met for true KL projection (both diagonal and full covariance)
 3. Gradients can be computed through all projections:
    - Both through projection operation and trust region loss
    - Gradients have correct shapes and are finite
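The `compute_gaussian_kl` helper referenced by the tests is not part of this diff. A minimal diagonal-case equivalent, assuming the `loc`/`scale` parameter dicts used throughout these files (a sketch, not the repository's implementation):

import jax.numpy as jnp

def diag_gaussian_kl(params, old_params):
    # KL(new || old) for diagonal Gaussians, reduced over the feature axis.
    mean, scale = params["loc"], params["scale"]
    old_mean, old_scale = old_params["loc"], old_params["scale"]
    var_ratio = jnp.square(scale / old_scale)
    mean_term = jnp.square((mean - old_mean) / old_scale)
    return 0.5 * jnp.sum(var_ratio + mean_term - 1.0 - jnp.log(var_ratio), axis=-1)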

View File

@@ -1,6 +1,7 @@
 import jax.numpy as jnp
 from .base_projection import BaseProjection
 from typing import Dict
+import jax
 
 class FrobeniusProjection(BaseProjection):
     def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
@@ -64,12 +65,26 @@ class FrobeniusProjection(BaseProjection):
         old_mean, old_cov = q
 
         if self.scale_prec:
-            prec_old = jnp.linalg.inv(old_cov)
-            mean_part = jnp.sum(jnp.matmul(mean - old_mean, prec_old) * (mean - old_mean), axis=-1)
-            cov_part = jnp.sum(prec_old * cov, axis=(-2, -1)) - jnp.log(jnp.linalg.det(jnp.matmul(prec_old, cov))) - mean.shape[-1]
+            # Mahalanobis distance for mean
+            diff = mean - old_mean
+            if old_cov.ndim == mean.ndim:  # diagonal case
+                mean_part = jnp.sum(jnp.square(diff / old_cov), axis=-1)
+            else:
+                solved = jax.scipy.linalg.solve_triangular(
+                    old_cov, diff[..., None], lower=True
+                )
+                mean_part = jnp.sum(jnp.square(solved.squeeze(-1)), axis=-1)
         else:
             mean_part = jnp.sum(jnp.square(mean - old_mean), axis=-1)
-            cov_part = jnp.sum(jnp.square(cov - old_cov), axis=(-2, -1))
+
+        # Frobenius norm for covariance
+        if cov.ndim == mean.ndim:  # diagonal case
+            diff = old_cov - cov
+            cov_part = jnp.sum(jnp.square(diff), axis=-1)
+        else:
+            diff = jnp.matmul(old_cov, jnp.swapaxes(old_cov, -1, -2)) - \
+                   jnp.matmul(cov, jnp.swapaxes(cov, -1, -2))
+            cov_part = jnp.sum(jnp.square(diff), axis=(-2, -1))
 
         return mean_part, cov_part
@@ -83,7 +98,7 @@ class FrobeniusProjection(BaseProjection):
     def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray,
                         cov_part: jnp.ndarray) -> jnp.ndarray:
-        batch_shape = cov.shape[:-2]
+        batch_shape = cov.shape[:-2] if cov.ndim > 2 else cov.shape[:-1]
         cov_mask = cov_part > self.cov_bound
 
         eta = jnp.ones(batch_shape, dtype=cov.dtype)
@@ -93,12 +108,10 @@ class FrobeniusProjection(BaseProjection):
         eta = jnp.maximum(-eta, eta)
 
         if self.full_cov:
-            new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
-            proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov)
-            return jnp.linalg.cholesky(proj_cov)
+            new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / \
+                      (1. + eta + 1e-16)[..., None, None]
         else:
+            # For diagonal case, simple broadcasting
             new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
-            return jnp.where(cov_mask[..., None], jnp.sqrt(new_cov), cov)
+
+        proj_cov = jnp.where(cov_mask[..., None] if not self.full_cov else cov_mask[..., None, None],
+                             new_cov, cov)
+        return proj_cov
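The rewritten `scale_prec` branch computes the Mahalanobis term with a triangular solve instead of forming an explicit inverse: when `old_cov` holds a lower-triangular Cholesky factor L of the covariance Sigma, ||L^-1 d||^2 equals d^T Sigma^-1 d. A quick equivalence check, with setup values assumed purely for illustration:

import jax
import jax.numpy as jnp

dim = 4
a = jax.random.normal(jax.random.PRNGKey(0), (dim, dim))
old_chol = jnp.linalg.cholesky(a @ a.T + dim * jnp.eye(dim))  # lower-triangular L
diff = jax.random.normal(jax.random.PRNGKey(1), (dim,))

# Triangular-solve form used in the new code: ||L^-1 d||^2
solved = jax.scipy.linalg.solve_triangular(old_chol, diff[..., None], lower=True)
mean_part_solve = jnp.sum(jnp.square(solved.squeeze(-1)))

# Equivalent explicit-inverse computation: d^T (L L^T)^-1 d
mean_part_inv = diff @ jnp.linalg.inv(old_chol @ old_chol.T) @ diff

assert jnp.allclose(mean_part_solve, mean_part_inv, atol=1e-4)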

View File

@@ -13,6 +13,24 @@ from .exception_projection import makeExceptionProjection
 MAX_EVAL = 1000
 
+# Cache for projection operators
+_diag_proj_op = None
+_full_proj_op = None
+
+def _get_diag_proj_op(batch_shape, dim):
+    global _diag_proj_op
+    if _diag_proj_op is None:
+        _diag_proj_op = cpp_projection.BatchedDiagCovOnlyProjection(
+            batch_shape, dim, max_eval=MAX_EVAL)
+    return _diag_proj_op
+
+def _get_full_proj_op(batch_shape, dim):
+    global _full_proj_op
+    if _full_proj_op is None:
+        _full_proj_op = cpp_projection.BatchedCovOnlyProjection(
+            batch_shape, dim, max_eval=MAX_EVAL)
+    return _full_proj_op
+
 class KLProjection(BaseProjection):
     """KL divergence-based projection for Gaussian policies.
@@ -166,6 +184,49 @@ class KLProjection(BaseProjection):
             if key not in policy_params or key not in old_policy_params:
                 raise KeyError(f"Missing required key '{key}' in policy parameters")
 
+@partial(jax.custom_vjp, nondiff_argnums=(2,))
+def project_diag_covariance(cov, old_cov, eps_cov):
+    """JAX wrapper for C++ diagonal covariance projection"""
+    batch_shape = cov.shape[0]
+    dim = cov.shape[-1]
+
+    cov_np = np.asarray(cov)
+    old_cov_np = np.asarray(old_cov)
+    eps = eps_cov * np.ones(batch_shape, dtype=old_cov_np.dtype)
+
+    p_op = _get_diag_proj_op(batch_shape, dim)
+    try:
+        proj_cov = p_op.forward(eps, old_cov_np, cov_np)
+    except Exception:
+        proj_cov = cov_np  # Return input on failure
+
+    return jnp.array(proj_cov)
+
+def project_diag_covariance_fwd(cov, old_cov, eps_cov):
+    y = project_diag_covariance(cov, old_cov, eps_cov)
+    return y, (cov, old_cov)
+
+def project_diag_covariance_bwd(eps_cov, res, g):
+    cov, old_cov = res
+    # Convert to numpy for C++ backward pass
+    g_np = np.asarray(g)
+    batch_shape = g_np.shape[0]
+    dim = g_np.shape[-1]
+    # Get C++ projection operator
+    p_op = _get_diag_proj_op(batch_shape, dim)
+    # Run C++ backward pass
+    grad_cov = p_op.backward(g_np)
+    # Convert back to JAX array
+    return jnp.array(grad_cov), None
+
+# Register VJP rule for diagonal covariance projection
+project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd)
+
 @partial(jax.custom_vjp, nondiff_argnums=(3,))
 def project_full_covariance(cov, chol, old_chol, eps_cov):
     """JAX wrapper for C++ full covariance projection"""
@@ -179,7 +240,7 @@ def project_full_covariance(cov, chol, old_chol, eps_cov):
     eps = eps_cov * np.ones(batch_shape)
 
     # Create C++ projection operator directly
-    p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
+    p_op = _get_full_proj_op(batch_shape, dim)
 
     # Run C++ projection
     proj_cov = p_op.forward(eps, old_chol_np, chol_np, cov_np)
@@ -203,7 +264,7 @@ def project_full_covariance_bwd(eps_cov, res, g):
     dim = g_np.shape[-1]
 
     # Get C++ projection operator
-    p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
+    p_op = _get_full_proj_op(batch_shape, dim)
 
     # Run C++ backward pass
     grad_cov = p_op.backward(g_np)
@@ -214,48 +275,5 @@ def project_full_covariance_bwd(eps_cov, res, g):
 # Register VJP rule for full covariance projection
 project_full_covariance.defvjp(project_full_covariance_fwd, project_full_covariance_bwd)
 
-@partial(jax.custom_vjp, nondiff_argnums=(2,))
-def project_diag_covariance(cov, old_cov, eps_cov):
-    """JAX wrapper for C++ diagonal covariance projection"""
-    # Convert JAX arrays to numpy for C++ function
-    cov_np = np.asarray(cov)
-    old_cov_np = np.asarray(old_cov)
-
-    batch_shape = cov_np.shape[0]
-    dim = cov_np.shape[-1]
-
-    eps = eps_cov * np.ones(batch_shape)
-
-    # Create C++ projection operator directly
-    p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
-
-    # Run C++ projection
-    proj_cov = p_op.forward(eps, old_cov_np, cov_np)
-
-    # Convert back to JAX array
-    return jnp.array(proj_cov)
-
-def project_diag_covariance_fwd(cov, old_cov, eps_cov):
-    y = project_diag_covariance(cov, old_cov, eps_cov)
-    return y, (cov, old_cov)
-
-def project_diag_covariance_bwd(eps_cov, res, g):
-    cov, old_cov = res
-    # Convert to numpy for C++ backward pass
-    g_np = np.asarray(g)
-    batch_shape = g_np.shape[0]
-    dim = g_np.shape[-1]
-    # Get C++ projection operator
-    p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
-    # Run C++ backward pass
-    grad_cov = p_op.backward(g_np)
-    # Convert back to JAX array
-    return jnp.array(grad_cov), None
-
-# Register VJP rule for diagonal covariance projection
-project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd)
-
 if not cpp_projection_available:
     KLProjection = makeExceptionProjection("ITPAL (C++ library) is not available. Please install the C++ library to use this projection.")
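The relocated `project_diag_covariance` keeps the same `jax.custom_vjp` wiring as the full-covariance wrapper: a host-side forward through the C++ operator, residuals saved in `_fwd`, and a `_bwd` that delegates to the operator's `backward`. A self-contained sketch of that pattern with a pure-numpy stand-in for the C++ op (illustrative only; `clip_norm` and its straight-through gradient are not part of this repository):

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def clip_norm(x, max_norm):
    # Host-side numpy op standing in for the C++ projection.
    x_np = np.asarray(x)
    scale = min(1.0, max_norm / (np.linalg.norm(x_np) + 1e-16))
    return jnp.array(scale * x_np)

def clip_norm_fwd(x, max_norm):
    y = clip_norm(x, max_norm)
    return y, (x,)  # residuals for the backward pass

def clip_norm_bwd(max_norm, res, g):
    # Crude straight-through surrogate; the real code delegates to p_op.backward(g).
    return (g,)

clip_norm.defvjp(clip_norm_fwd, clip_norm_bwd)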

View File

@@ -1,6 +1,7 @@
 import jax.numpy as jnp
 from .base_projection import BaseProjection
 from typing import Dict, Tuple
+import jax
 
 def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
     """
@@ -22,7 +23,8 @@ class WassersteinProjection(BaseProjection):
     def project(self, policy_params: Dict[str, jnp.ndarray],
                 old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
-        assert not self.full_cov, "Wasserstein projection only supports diagonal covariance"
+        if self.full_cov:
+            print("Warning: Wasserstein projection with full covariance is wip, we recommend using diagonal covariance instead.")
 
         mean = policy_params["loc"]  # shape: (batch_size, dim)
         old_mean = old_policy_params["loc"]
@@ -73,27 +75,47 @@ class WassersteinProjection(BaseProjection):
     def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray,
                           scale_part: jnp.ndarray) -> jnp.ndarray:
-        """Project scale parameters (standard deviations for diagonal case)"""
-        diff = scale - old_scale
-        norm = jnp.sqrt(scale_part)
-        if scale.ndim == 2:  # Batched scale
-            norm = norm[..., None]
-        return jnp.where(norm > self.cov_bound,
-                         old_scale + diff * self.cov_bound / norm,
-                         scale)
+        """Project scale parameters using multiplicative update.
+
+        Args:
+            scale: Current scale/sqrt of covariance
+            old_scale: Previous scale/sqrt of covariance
+            scale_part: W2 distance between scales
+
+        Returns:
+            Projected scale that satisfies the trust region constraint
+        """
+        # Check if projection needed
+        cov_mask = scale_part > self.cov_bound
+
+        # Compute eta (multiplier for the update)
+        batch_shape = scale.shape[:-2] if scale.ndim > 2 else scale.shape[:-1]
+        eta = jnp.ones(batch_shape, dtype=scale.dtype)
+        eta = jnp.where(cov_mask,
+                        jnp.sqrt(scale_part / self.cov_bound) - 1.,
+                        eta)
+        eta = jnp.maximum(-eta, eta)
+
+        # Multiplicative update with correct broadcasting
+        if scale.ndim > 2:  # Full covariance case
+            new_scale = (scale + jnp.einsum('...,...ij->...ij', eta, old_scale)) / \
+                        (1. + eta + 1e-16)[..., None, None]
+            mask = cov_mask[..., None, None]
+        else:  # Diagonal case
+            new_scale = (scale + eta[..., None] * old_scale) / \
+                        (1. + eta + 1e-16)[..., None]
+            mask = cov_mask[..., None]
+
+        return jnp.where(mask, new_scale, scale)
 
     def _gaussian_wasserstein(self, p, q):
         mean, scale = p
         mean_other, scale_other = q
 
-        # Keep batch dimension by only summing over feature dimension
-        mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)  # -> (batch_size,)
+        # Euclidean distance for mean part (we're in diagonal case)
+        mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)
 
-        if scale.ndim == mean.ndim:  # Batched scale
-            cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale, axis=-1)
-        else:  # Non-contextual scale (single scale for all batches)
-            cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale)
+        # Standard W2 objective for covariance (diagonal case)
+        cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale, axis=-1)
 
         return mean_part, cov_part
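The simplified `_gaussian_wasserstein` is the standard diagonal-case W2 decomposition: the covariance term is just the squared Euclidean distance between standard-deviation vectors, since sum(s_old^2 + s^2 - 2*s_old*s) = sum((s - s_old)^2). A quick numeric check of that identity (values chosen for illustration):

import jax.numpy as jnp

s = jnp.array([0.5, 1.0, 2.0])
s_old = jnp.array([0.7, 0.9, 1.5])
lhs = jnp.sum(s_old**2 + s**2 - 2.0 * s_old * s)  # cov_part as written above
rhs = jnp.sum(jnp.square(s - s_old))              # equivalent closed form
assert jnp.allclose(lhs, rhs)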

View File

@@ -94,8 +94,8 @@ def test_diagonal_projection(ProjectionClass, needs_cpp, gaussian_params):
     assert jnp.all(jnp.isfinite(proj_params["scale"]))
     assert jnp.all(proj_params["scale"] > 0)
 
-    # Only check KL bounds for KL projection (and W2, which should approx hold as well)
-    if ProjectionClass in [KLProjection, WassersteinProjection]:
+    # Only check KL bounds for KL projection
+    if ProjectionClass in [KLProjection]:
         kl = compute_gaussian_kl(proj_params, gaussian_params["old_params"])
         max_kl = (mean_bound + cov_bound) * 1.1  # Allow 10% margin
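The 10% margin leaves room for the approximate numerical solver behind the C++ KL projection; the bound itself comes from projecting the mean and covariance parts to `mean_bound` and `cov_bound` separately. A hypothetical end-to-end check in the style of this test (constructor arguments and fixture keys assumed, not shown in this diff):

proj = KLProjection(mean_bound=mean_bound, cov_bound=cov_bound)
proj_params = proj.project(gaussian_params["params"], gaussian_params["old_params"])
kl = compute_gaussian_kl(proj_params, gaussian_params["old_params"])
assert jnp.all(kl <= (mean_bound + cov_bound) * 1.1)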