Compare commits


No commits in common. "de2b9a10d6f807a40821fbea42aca3d4eb2356d7" and "44eb3335ff413509df52a380862dd210a2bf669d" have entirely different histories.

5 changed files with 79 additions and 130 deletions

View File

@@ -65,7 +65,9 @@ pytest tests/test_projections.py
 *Note*: The test suite verifies:
 1. All projections run without errors and maintain basic properties (shapes, positive definiteness)
-2. KL bounds are actually (approximately) met for true KL projection (both diagonal and full covariance)
+2. KL bounds are actually (approximately) met for:
+   - KL projection (both diagonal and full covariance)
+   - Wasserstein projection (diagonal covariance only)
 3. Gradients can be computed through all projections:
    - Both through projection operation and trust region loss
    - Gradients have correct shapes and are finite
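For reference, the bound check behind item 2 reduces to the closed-form KL between diagonal Gaussians. A minimal sketch of that computation (the test suite's `compute_gaussian_kl` helper exists, but this particular implementation and the names are illustrative assumptions):

import jax.numpy as jnp

def diag_gaussian_kl(mean, scale, old_mean, old_scale):
    # KL(new || old) for diagonal Gaussians, summed over the feature axis
    var_ratio = jnp.square(scale / old_scale)
    maha = jnp.square((mean - old_mean) / old_scale)
    return 0.5 * jnp.sum(var_ratio + maha - 1.0 - jnp.log(var_ratio), axis=-1)

# The tests then assert kl <= (mean_bound + cov_bound) * 1.1 (a 10% margin).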

View File

@@ -1,7 +1,6 @@
 import jax.numpy as jnp
 from .base_projection import BaseProjection
 from typing import Dict
-import jax
 
 class FrobeniusProjection(BaseProjection):
     def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
@@ -65,26 +64,12 @@ class FrobeniusProjection(BaseProjection):
         old_mean, old_cov = q
 
         if self.scale_prec:
-            # Mahalanobis distance for mean
-            diff = mean - old_mean
-            if old_cov.ndim == mean.ndim:  # diagonal case
-                mean_part = jnp.sum(jnp.square(diff / old_cov), axis=-1)
-            else:
-                solved = jax.scipy.linalg.solve_triangular(
-                    old_cov, diff[..., None], lower=True
-                )
-                mean_part = jnp.sum(jnp.square(solved.squeeze(-1)), axis=-1)
+            prec_old = jnp.linalg.inv(old_cov)
+            mean_part = jnp.sum(jnp.matmul(mean - old_mean, prec_old) * (mean - old_mean), axis=-1)
+            cov_part = jnp.sum(prec_old * cov, axis=(-2, -1)) - jnp.log(jnp.linalg.det(jnp.matmul(prec_old, cov))) - mean.shape[-1]
         else:
             mean_part = jnp.sum(jnp.square(mean - old_mean), axis=-1)
-
-        # Frobenius norm for covariance
-        if cov.ndim == mean.ndim:  # diagonal case
-            diff = old_cov - cov
-            cov_part = jnp.sum(jnp.square(diff), axis=-1)
-        else:
-            diff = jnp.matmul(old_cov, jnp.swapaxes(old_cov, -1, -2)) - \
-                   jnp.matmul(cov, jnp.swapaxes(cov, -1, -2))
-            cov_part = jnp.sum(jnp.square(diff), axis=(-2, -1))
+            cov_part = jnp.sum(jnp.square(cov - old_cov), axis=(-2, -1))
 
         return mean_part, cov_part
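The new `scale_prec` branch computes a Mahalanobis mean term and a trace-minus-log-det covariance term; read together, `mean_part + cov_part` equals twice the Gaussian KL from the new to the old distribution. A quick numeric check of that identity (a standalone sketch with made-up values, not repository code):

import jax.numpy as jnp

d = 3
old_cov = jnp.eye(d) * 1.5 + 0.2   # SPD old covariance
cov = jnp.eye(d) + 0.1             # SPD new covariance
mean, old_mean = jnp.ones(d), jnp.zeros(d)

prec_old = jnp.linalg.inv(old_cov)
mean_part = (mean - old_mean) @ prec_old @ (mean - old_mean)
cov_part = jnp.trace(prec_old @ cov) - jnp.log(jnp.linalg.det(prec_old @ cov)) - d

# Closed-form KL(N(mean, cov) || N(old_mean, old_cov))
kl = 0.5 * (jnp.trace(prec_old @ cov) - d
            + (mean - old_mean) @ prec_old @ (mean - old_mean)
            + jnp.log(jnp.linalg.det(old_cov) / jnp.linalg.det(cov)))
assert jnp.allclose(mean_part + cov_part, 2.0 * kl, atol=1e-5)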
@@ -98,20 +83,22 @@ class FrobeniusProjection(BaseProjection):
     def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray,
                         cov_part: jnp.ndarray) -> jnp.ndarray:
-        batch_shape = cov.shape[:-2] if cov.ndim > 2 else cov.shape[:-1]
+        batch_shape = cov.shape[:-2]
         cov_mask = cov_part > self.cov_bound
 
         eta = jnp.ones(batch_shape, dtype=cov.dtype)
         eta = jnp.where(cov_mask,
                         jnp.sqrt(cov_part / self.cov_bound) - 1.,
                         eta)
         eta = jnp.maximum(-eta, eta)
 
         if self.full_cov:
-            new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / \
-                      (1. + eta + 1e-16)[..., None, None]
-            proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov)
-            return jnp.linalg.cholesky(proj_cov)
+            new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
         else:
-            # For diagonal case, simple broadcasting
             new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
-            return jnp.where(cov_mask[..., None], jnp.sqrt(new_cov), cov)
+
+        proj_cov = jnp.where(cov_mask[..., None] if not self.full_cov else cov_mask[..., None, None],
+                             new_cov, cov)
+        return proj_cov
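A note on the update rule that survives this rewrite: dividing by `1 + eta` shrinks the deviation from `old_cov` by `1 / (1 + eta)`, and since the squared Frobenius distance scales quadratically, `eta = sqrt(cov_part / cov_bound) - 1` lands exactly on the bound. A standalone check of that arithmetic for the diagonal case (illustrative values, not repository code):

import jax.numpy as jnp

cov_bound = 0.01
old_cov = jnp.array([0.5, 0.4, 0.3])
cov = jnp.array([0.9, 0.2, 0.35])  # violates the bound

cov_part = jnp.sum(jnp.square(cov - old_cov))  # squared Frobenius distance = 0.2025
eta = jnp.sqrt(cov_part / cov_bound) - 1.0     # = 3.5 here
new_cov = (cov + eta * old_cov) / (1.0 + eta)

# new_cov - old_cov = (cov - old_cov) / (1 + eta), so the distance hits the bound exactly
assert jnp.allclose(jnp.sum(jnp.square(new_cov - old_cov)), cov_bound, atol=1e-6)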

View File

@@ -13,24 +13,6 @@ from .exception_projection import makeExceptionProjection
 MAX_EVAL = 1000
 
-# Cache for projection operators
-_diag_proj_op = None
-_full_proj_op = None
-
-def _get_diag_proj_op(batch_shape, dim):
-    global _diag_proj_op
-    if _diag_proj_op is None:
-        _diag_proj_op = cpp_projection.BatchedDiagCovOnlyProjection(
-            batch_shape, dim, max_eval=MAX_EVAL)
-    return _diag_proj_op
-
-def _get_full_proj_op(batch_shape, dim):
-    global _full_proj_op
-    if _full_proj_op is None:
-        _full_proj_op = cpp_projection.BatchedCovOnlyProjection(
-            batch_shape, dim, max_eval=MAX_EVAL)
-    return _full_proj_op
-
 class KLProjection(BaseProjection):
     """KL divergence-based projection for Gaussian policies.
@@ -184,49 +166,6 @@ class KLProjection(BaseProjection):
         if key not in policy_params or key not in old_policy_params:
             raise KeyError(f"Missing required key '{key}' in policy parameters")
 
-@partial(jax.custom_vjp, nondiff_argnums=(2,))
-def project_diag_covariance(cov, old_cov, eps_cov):
-    """JAX wrapper for C++ diagonal covariance projection"""
-    batch_shape = cov.shape[0]
-    dim = cov.shape[-1]
-
-    cov_np = np.asarray(cov)
-    old_cov_np = np.asarray(old_cov)
-    eps = eps_cov * np.ones(batch_shape, dtype=old_cov_np.dtype)
-
-    p_op = _get_diag_proj_op(batch_shape, dim)
-    try:
-        proj_cov = p_op.forward(eps, old_cov_np, cov_np)
-    except:
-        proj_cov = cov_np  # Return input on failure
-
-    return jnp.array(proj_cov)
-
-def project_diag_covariance_fwd(cov, old_cov, eps_cov):
-    y = project_diag_covariance(cov, old_cov, eps_cov)
-    return y, (cov, old_cov)
-
-def project_diag_covariance_bwd(eps_cov, res, g):
-    cov, old_cov = res
-    # Convert to numpy for C++ backward pass
-    g_np = np.asarray(g)
-    batch_shape = g_np.shape[0]
-    dim = g_np.shape[-1]
-    # Get C++ projection operator
-    p_op = _get_diag_proj_op(batch_shape, dim)
-    # Run C++ backward pass
-    grad_cov = p_op.backward(g_np)
-    # Convert back to JAX array
-    return jnp.array(grad_cov), None
-
-# Register VJP rule for diagonal covariance projection
-project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd)
-
 @partial(jax.custom_vjp, nondiff_argnums=(3,))
 def project_full_covariance(cov, chol, old_chol, eps_cov):
     """JAX wrapper for C++ full covariance projection"""
@@ -240,7 +179,7 @@ def project_full_covariance(cov, chol, old_chol, eps_cov):
     eps = eps_cov * np.ones(batch_shape)
 
     # Create C++ projection operator directly
-    p_op = _get_full_proj_op(batch_shape, dim)
+    p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
 
     # Run C++ projection
     proj_cov = p_op.forward(eps, old_chol_np, chol_np, cov_np)
@@ -264,7 +203,7 @@ def project_full_covariance_bwd(eps_cov, res, g):
     dim = g_np.shape[-1]
 
     # Get C++ projection operator
-    p_op = _get_full_proj_op(batch_shape, dim)
+    p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
 
     # Run C++ backward pass
    grad_cov = p_op.backward(g_np)
@@ -275,5 +214,48 @@ def project_full_covariance_bwd(eps_cov, res, g):
 # Register VJP rule for full covariance projection
 project_full_covariance.defvjp(project_full_covariance_fwd, project_full_covariance_bwd)
 
+@partial(jax.custom_vjp, nondiff_argnums=(2,))
+def project_diag_covariance(cov, old_cov, eps_cov):
+    """JAX wrapper for C++ diagonal covariance projection"""
+    # Convert JAX arrays to numpy for C++ function
+    cov_np = np.asarray(cov)
+    old_cov_np = np.asarray(old_cov)
+    batch_shape = cov_np.shape[0]
+    dim = cov_np.shape[-1]
+    eps = eps_cov * np.ones(batch_shape)
+
+    # Create C++ projection operator directly
+    p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
+
+    # Run C++ projection
+    proj_cov = p_op.forward(eps, old_cov_np, cov_np)
+
+    # Convert back to JAX array
+    return jnp.array(proj_cov)
+
+def project_diag_covariance_fwd(cov, old_cov, eps_cov):
+    y = project_diag_covariance(cov, old_cov, eps_cov)
+    return y, (cov, old_cov)
+
+def project_diag_covariance_bwd(eps_cov, res, g):
+    cov, old_cov = res
+    # Convert to numpy for C++ backward pass
+    g_np = np.asarray(g)
+    batch_shape = g_np.shape[0]
+    dim = g_np.shape[-1]
+    # Get C++ projection operator
+    p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
+    # Run C++ backward pass
+    grad_cov = p_op.backward(g_np)
+    # Convert back to JAX array
+    return jnp.array(grad_cov), None
+
+# Register VJP rule for diagonal covariance projection
+project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd)
+
 if not cpp_projection_available:
     KLProjection = makeExceptionProjection("ITPAL (C++ library) is not available. Please install the C++ library to use this projection.")
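Since these wrappers convert their inputs with `np.asarray`, they only accept concrete (non-traced) arrays, i.e. they must be called outside `jax.jit`; the `defvjp` registration still makes them differentiable. A usage sketch, assuming ITPAL's `cpp_projection` is installed (shapes and the bound are illustrative):

import jax
import jax.numpy as jnp

old_cov = jnp.full((8, 4), 0.5)   # (batch, dim) diagonal covariances
cov = jnp.full((8, 4), 2.0)       # candidate that may violate the bound

proj = project_diag_covariance(cov, old_cov, 0.001)

# Gradients flow through the C++ backward pass registered with defvjp
grad = jax.grad(lambda c: project_diag_covariance(c, old_cov, 0.001).sum())(cov)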

View File

@@ -1,7 +1,6 @@
 import jax.numpy as jnp
 from .base_projection import BaseProjection
 from typing import Dict, Tuple
-import jax
 
 def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
     """
@@ -23,8 +22,7 @@ class WassersteinProjection(BaseProjection):
     def project(self, policy_params: Dict[str, jnp.ndarray],
                 old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
-        if self.full_cov:
-            print("Warning: Wasserstein projection with full covariance is wip, we recommend using diagonal covariance instead.")
+        assert not self.full_cov, "Wasserstein projection only supports diagonal covariance"
 
         mean = policy_params["loc"]  # shape: (batch_size, dim)
         old_mean = old_policy_params["loc"]
@@ -75,47 +73,27 @@ class WassersteinProjection(BaseProjection):
     def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray,
                           scale_part: jnp.ndarray) -> jnp.ndarray:
-        """Project scale parameters using multiplicative update.
-
-        Args:
-            scale: Current scale/sqrt of covariance
-            old_scale: Previous scale/sqrt of covariance
-            scale_part: W2 distance between scales
-
-        Returns:
-            Projected scale that satisfies the trust region constraint
-        """
-        # Check if projection needed
-        cov_mask = scale_part > self.cov_bound
-
-        # Compute eta (multiplier for the update)
-        batch_shape = scale.shape[:-2] if scale.ndim > 2 else scale.shape[:-1]
-        eta = jnp.ones(batch_shape, dtype=scale.dtype)
-        eta = jnp.where(cov_mask,
-                        jnp.sqrt(scale_part / self.cov_bound) - 1.,
-                        eta)
-        eta = jnp.maximum(-eta, eta)
-
-        # Multiplicative update with correct broadcasting
-        if scale.ndim > 2:  # Full covariance case
-            new_scale = (scale + jnp.einsum('...,...ij->...ij', eta, old_scale)) / \
-                        (1. + eta + 1e-16)[..., None, None]
-            mask = cov_mask[..., None, None]
-        else:  # Diagonal case
-            new_scale = (scale + eta[..., None] * old_scale) / \
-                        (1. + eta + 1e-16)[..., None]
-            mask = cov_mask[..., None]
-
-        return jnp.where(mask, new_scale, scale)
+        """Project scale parameters (standard deviations for diagonal case)"""
+        diff = scale - old_scale
+        norm = jnp.sqrt(scale_part)
+        if scale.ndim == 2:  # Batched scale
+            norm = norm[..., None]
+        return jnp.where(norm > self.cov_bound,
+                         old_scale + diff * self.cov_bound / norm,
+                         scale)
 
     def _gaussian_wasserstein(self, p, q):
         mean, scale = p
         mean_other, scale_other = q
 
-        # Euclidean distance for mean part (we're in diagonal case)
-        mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)
-        # Standard W2 objective for covariance (diagonal case)
-        cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale, axis=-1)
+        # Keep batch dimension by only summing over feature dimension
+        mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)  # -> (batch_size,)
+        if scale.ndim == mean.ndim:  # Batched scale
+            cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale, axis=-1)
+        else:  # Non-contextual scale (single scale for all batches)
+            cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale)
 
         return mean_part, cov_part
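For diagonal Gaussians the squared 2-Wasserstein distance decomposes into exactly the two returned terms, ||mean - mean_other||^2 + sum((scale - scale_other)^2), and the additive update in `_scale_projection` rescales the scale deviation so its norm equals `cov_bound`. A quick check of that property (illustrative values, not repository code):

import jax.numpy as jnp

cov_bound = 0.1
old_scale = jnp.array([0.5, 0.6, 0.7])
scale = jnp.array([0.9, 0.3, 0.8])

diff = scale - old_scale
norm = jnp.sqrt(jnp.sum(jnp.square(diff)))  # sqrt of the W2 covariance term
projected = jnp.where(norm > cov_bound,
                      old_scale + diff * cov_bound / norm,
                      scale)

# The projected deviation sits exactly on the bound
assert jnp.allclose(jnp.linalg.norm(projected - old_scale), cov_bound, atol=1e-6)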

View File

@@ -94,8 +94,8 @@ def test_diagonal_projection(ProjectionClass, needs_cpp, gaussian_params):
     assert jnp.all(jnp.isfinite(proj_params["scale"]))
     assert jnp.all(proj_params["scale"] > 0)
 
-    # Only check KL bounds for KL projection
-    if ProjectionClass in [KLProjection]:
+    # Only check KL bounds for KL projection (and W2, which should approximately hold as well)
+    if ProjectionClass in [KLProjection, WassersteinProjection]:
         kl = compute_gaussian_kl(proj_params, gaussian_params["old_params"])
         max_kl = (mean_bound + cov_bound) * 1.1  # Allow 10% margin