Compare commits

..

No commits in common. "de2b9a10d6f807a40821fbea42aca3d4eb2356d7" and "44eb3335ff413509df52a380862dd210a2bf669d" have entirely different histories.

5 changed files with 79 additions and 130 deletions

View File

@ -65,7 +65,9 @@ pytest tests/test_projections.py
*Note*: The test suite verifies:
1. All projections run without errors and maintain basic properties (shapes, positive definiteness)
2. KL bounds are actually (approximately) met for true KL projection (both diagonal and full covariance)
2. KL bounds are actually (approximately) met for:
- KL projection (both diagonal and full covariance)
- Wasserstein projection (diagonal covariance only)
3. Gradients can be computed through all projections:
- Both through projection operation and trust region loss
- Gradients have correct shapes and are finite

View File

@ -1,7 +1,6 @@
import jax.numpy as jnp
from .base_projection import BaseProjection
from typing import Dict
import jax
class FrobeniusProjection(BaseProjection):
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
@ -65,26 +64,12 @@ class FrobeniusProjection(BaseProjection):
old_mean, old_cov = q
if self.scale_prec:
# Mahalanobis distance for mean
diff = mean - old_mean
if old_cov.ndim == mean.ndim: # diagonal case
mean_part = jnp.sum(jnp.square(diff / old_cov), axis=-1)
else:
solved = jax.scipy.linalg.solve_triangular(
old_cov, diff[..., None], lower=True
)
mean_part = jnp.sum(jnp.square(solved.squeeze(-1)), axis=-1)
prec_old = jnp.linalg.inv(old_cov)
mean_part = jnp.sum(jnp.matmul(mean - old_mean, prec_old) * (mean - old_mean), axis=-1)
cov_part = jnp.sum(prec_old * cov, axis=(-2, -1)) - jnp.log(jnp.linalg.det(jnp.matmul(prec_old, cov))) - mean.shape[-1]
else:
mean_part = jnp.sum(jnp.square(mean - old_mean), axis=-1)
# Frobenius norm for covariance
if cov.ndim == mean.ndim: # diagonal case
diff = old_cov - cov
cov_part = jnp.sum(jnp.square(diff), axis=-1)
else:
diff = jnp.matmul(old_cov, jnp.swapaxes(old_cov, -1, -2)) - \
jnp.matmul(cov, jnp.swapaxes(cov, -1, -2))
cov_part = jnp.sum(jnp.square(diff), axis=(-2, -1))
cov_part = jnp.sum(jnp.square(cov - old_cov), axis=(-2, -1))
return mean_part, cov_part
@ -98,20 +83,22 @@ class FrobeniusProjection(BaseProjection):
def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray,
cov_part: jnp.ndarray) -> jnp.ndarray:
batch_shape = cov.shape[:-2] if cov.ndim > 2 else cov.shape[:-1]
batch_shape = cov.shape[:-2]
cov_mask = cov_part > self.cov_bound
eta = jnp.ones(batch_shape, dtype=cov.dtype)
eta = jnp.where(cov_mask,
jnp.sqrt(cov_part / self.cov_bound) - 1.,
eta)
jnp.sqrt(cov_part / self.cov_bound) - 1.,
eta)
eta = jnp.maximum(-eta, eta)
if self.full_cov:
new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / \
(1. + eta + 1e-16)[..., None, None]
proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov)
return jnp.linalg.cholesky(proj_cov)
new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
else:
# For diagonal case, simple broadcasting
new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
return jnp.where(cov_mask[..., None], jnp.sqrt(new_cov), cov)
proj_cov = jnp.where(cov_mask[..., None] if not self.full_cov else cov_mask[..., None, None],
new_cov, cov)
return proj_cov

View File

@ -13,24 +13,6 @@ from .exception_projection import makeExceptionProjection
MAX_EVAL = 1000
# Cache for projection operators
_diag_proj_op = None
_full_proj_op = None
def _get_diag_proj_op(batch_shape, dim):
global _diag_proj_op
if _diag_proj_op is None:
_diag_proj_op = cpp_projection.BatchedDiagCovOnlyProjection(
batch_shape, dim, max_eval=MAX_EVAL)
return _diag_proj_op
def _get_full_proj_op(batch_shape, dim):
global _full_proj_op
if _full_proj_op is None:
_full_proj_op = cpp_projection.BatchedCovOnlyProjection(
batch_shape, dim, max_eval=MAX_EVAL)
return _full_proj_op
class KLProjection(BaseProjection):
"""KL divergence-based projection for Gaussian policies.
@ -184,49 +166,6 @@ class KLProjection(BaseProjection):
if key not in policy_params or key not in old_policy_params:
raise KeyError(f"Missing required key '{key}' in policy parameters")
@partial(jax.custom_vjp, nondiff_argnums=(2,))
def project_diag_covariance(cov, old_cov, eps_cov):
"""JAX wrapper for C++ diagonal covariance projection"""
batch_shape = cov.shape[0]
dim = cov.shape[-1]
cov_np = np.asarray(cov)
old_cov_np = np.asarray(old_cov)
eps = eps_cov * np.ones(batch_shape, dtype=old_cov_np.dtype)
p_op = _get_diag_proj_op(batch_shape, dim)
try:
proj_cov = p_op.forward(eps, old_cov_np, cov_np)
except:
proj_cov = cov_np # Return input on failure
return jnp.array(proj_cov)
def project_diag_covariance_fwd(cov, old_cov, eps_cov):
y = project_diag_covariance(cov, old_cov, eps_cov)
return y, (cov, old_cov)
def project_diag_covariance_bwd(eps_cov, res, g):
cov, old_cov = res
# Convert to numpy for C++ backward pass
g_np = np.asarray(g)
batch_shape = g_np.shape[0]
dim = g_np.shape[-1]
# Get C++ projection operator
p_op = _get_diag_proj_op(batch_shape, dim)
# Run C++ backward pass
grad_cov = p_op.backward(g_np)
# Convert back to JAX array
return jnp.array(grad_cov), None
# Register VJP rule for diagonal covariance projection
project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd)
@partial(jax.custom_vjp, nondiff_argnums=(3,))
def project_full_covariance(cov, chol, old_chol, eps_cov):
"""JAX wrapper for C++ full covariance projection"""
@ -240,7 +179,7 @@ def project_full_covariance(cov, chol, old_chol, eps_cov):
eps = eps_cov * np.ones(batch_shape)
# Create C++ projection operator directly
p_op = _get_full_proj_op(batch_shape, dim)
p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
# Run C++ projection
proj_cov = p_op.forward(eps, old_chol_np, chol_np, cov_np)
@ -264,7 +203,7 @@ def project_full_covariance_bwd(eps_cov, res, g):
dim = g_np.shape[-1]
# Get C++ projection operator
p_op = _get_full_proj_op(batch_shape, dim)
p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
# Run C++ backward pass
grad_cov = p_op.backward(g_np)
@ -275,5 +214,48 @@ def project_full_covariance_bwd(eps_cov, res, g):
# Register VJP rule for full covariance projection
project_full_covariance.defvjp(project_full_covariance_fwd, project_full_covariance_bwd)
@partial(jax.custom_vjp, nondiff_argnums=(2,))
def project_diag_covariance(cov, old_cov, eps_cov):
"""JAX wrapper for C++ diagonal covariance projection"""
# Convert JAX arrays to numpy for C++ function
cov_np = np.asarray(cov)
old_cov_np = np.asarray(old_cov)
batch_shape = cov_np.shape[0]
dim = cov_np.shape[-1]
eps = eps_cov * np.ones(batch_shape)
# Create C++ projection operator directly
p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
# Run C++ projection
proj_cov = p_op.forward(eps, old_cov_np, cov_np)
# Convert back to JAX array
return jnp.array(proj_cov)
def project_diag_covariance_fwd(cov, old_cov, eps_cov):
y = project_diag_covariance(cov, old_cov, eps_cov)
return y, (cov, old_cov)
def project_diag_covariance_bwd(eps_cov, res, g):
cov, old_cov = res
# Convert to numpy for C++ backward pass
g_np = np.asarray(g)
batch_shape = g_np.shape[0]
dim = g_np.shape[-1]
# Get C++ projection operator
p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
# Run C++ backward pass
grad_cov = p_op.backward(g_np)
# Convert back to JAX array
return jnp.array(grad_cov), None
# Register VJP rule for diagonal covariance projection
project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd)
if not cpp_projection_available:
KLProjection = makeExceptionProjection("ITPAL (C++ library) is not available. Please install the C++ library to use this projection.")

View File

@ -1,7 +1,6 @@
import jax.numpy as jnp
from .base_projection import BaseProjection
from typing import Dict, Tuple
import jax
def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
"""
@ -23,8 +22,7 @@ class WassersteinProjection(BaseProjection):
def project(self, policy_params: Dict[str, jnp.ndarray],
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
if self.full_cov:
print("Warning: Wasserstein projection with full covariance is wip, we recommend using diagonal covariance instead.")
assert not self.full_cov, "Wasserstein projection only supports diagonal covariance"
mean = policy_params["loc"] # shape: (batch_size, dim)
old_mean = old_policy_params["loc"]
@ -75,47 +73,27 @@ class WassersteinProjection(BaseProjection):
def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray,
scale_part: jnp.ndarray) -> jnp.ndarray:
"""Project scale parameters using multiplicative update.
"""Project scale parameters (standard deviations for diagonal case)"""
diff = scale - old_scale
norm = jnp.sqrt(scale_part)
Args:
scale: Current scale/sqrt of covariance
old_scale: Previous scale/sqrt of covariance
scale_part: W2 distance between scales
if scale.ndim == 2: # Batched scale
norm = norm[..., None]
Returns:
Projected scale that satisfies the trust region constraint
"""
# Check if projection needed
cov_mask = scale_part > self.cov_bound
# Compute eta (multiplier for the update)
batch_shape = scale.shape[:-2] if scale.ndim > 2 else scale.shape[:-1]
eta = jnp.ones(batch_shape, dtype=scale.dtype)
eta = jnp.where(cov_mask,
jnp.sqrt(scale_part / self.cov_bound) - 1.,
eta)
eta = jnp.maximum(-eta, eta)
# Multiplicative update with correct broadcasting
if scale.ndim > 2: # Full covariance case
new_scale = (scale + jnp.einsum('...,...ij->...ij', eta, old_scale)) / \
(1. + eta + 1e-16)[..., None, None]
mask = cov_mask[..., None, None]
else: # Diagonal case
new_scale = (scale + eta[..., None] * old_scale) / \
(1. + eta + 1e-16)[..., None]
mask = cov_mask[..., None]
return jnp.where(mask, new_scale, scale)
return jnp.where(norm > self.cov_bound,
old_scale + diff * self.cov_bound / norm,
scale)
def _gaussian_wasserstein(self, p, q):
mean, scale = p
mean_other, scale_other = q
# Euclidean distance for mean part (we're in diagonal case)
mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)
# Keep batch dimension by only summing over feature dimension
mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1) # -> (batch_size,)
# Standard W2 objective for covariance (diagonal case)
cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale, axis=-1)
if scale.ndim == mean.ndim: # Batched scale
cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale, axis=-1)
else: # Non-contextual scale (single scale for all batches)
cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale)
return mean_part, cov_part

View File

@ -94,8 +94,8 @@ def test_diagonal_projection(ProjectionClass, needs_cpp, gaussian_params):
assert jnp.all(jnp.isfinite(proj_params["scale"]))
assert jnp.all(proj_params["scale"] > 0)
# Only check KL bounds for KL projection
if ProjectionClass in [KLProjection]:
# Only check KL bounds for KL projection (and W2, which should approx hold as well)
if ProjectionClass in [KLProjection, WassersteinProjection]:
kl = compute_gaussian_kl(proj_params, gaussian_params["old_params"])
max_kl = (mean_bound + cov_bound) * 1.1 # Allow 10% margin