Compare commits
3 Commits
44eb3335ff
...
de2b9a10d6
Author | SHA1 | Date | |
---|---|---|---|
de2b9a10d6 | |||
9fb0014a99 | |||
8e991ae05b |
@ -65,9 +65,7 @@ pytest tests/test_projections.py
|
||||
*Note*: The test suite verifies:
|
||||
|
||||
1. All projections run without errors and maintain basic properties (shapes, positive definiteness)
|
||||
2. KL bounds are actually (approximately) met for:
|
||||
- KL projection (both diagonal and full covariance)
|
||||
- Wasserstein projection (diagonal covariance only)
|
||||
2. KL bounds are actually (approximately) met for true KL projection (both diagonal and full covariance)
|
||||
3. Gradients can be computed through all projections:
|
||||
- Both through projection operation and trust region loss
|
||||
- Gradients have correct shapes and are finite
|
@ -1,6 +1,7 @@
|
||||
import jax.numpy as jnp
|
||||
from .base_projection import BaseProjection
|
||||
from typing import Dict
|
||||
import jax
|
||||
|
||||
class FrobeniusProjection(BaseProjection):
|
||||
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
|
||||
@ -64,12 +65,26 @@ class FrobeniusProjection(BaseProjection):
|
||||
old_mean, old_cov = q
|
||||
|
||||
if self.scale_prec:
|
||||
prec_old = jnp.linalg.inv(old_cov)
|
||||
mean_part = jnp.sum(jnp.matmul(mean - old_mean, prec_old) * (mean - old_mean), axis=-1)
|
||||
cov_part = jnp.sum(prec_old * cov, axis=(-2, -1)) - jnp.log(jnp.linalg.det(jnp.matmul(prec_old, cov))) - mean.shape[-1]
|
||||
# Mahalanobis distance for mean
|
||||
diff = mean - old_mean
|
||||
if old_cov.ndim == mean.ndim: # diagonal case
|
||||
mean_part = jnp.sum(jnp.square(diff / old_cov), axis=-1)
|
||||
else:
|
||||
solved = jax.scipy.linalg.solve_triangular(
|
||||
old_cov, diff[..., None], lower=True
|
||||
)
|
||||
mean_part = jnp.sum(jnp.square(solved.squeeze(-1)), axis=-1)
|
||||
else:
|
||||
mean_part = jnp.sum(jnp.square(mean - old_mean), axis=-1)
|
||||
cov_part = jnp.sum(jnp.square(cov - old_cov), axis=(-2, -1))
|
||||
|
||||
# Frobenius norm for covariance
|
||||
if cov.ndim == mean.ndim: # diagonal case
|
||||
diff = old_cov - cov
|
||||
cov_part = jnp.sum(jnp.square(diff), axis=-1)
|
||||
else:
|
||||
diff = jnp.matmul(old_cov, jnp.swapaxes(old_cov, -1, -2)) - \
|
||||
jnp.matmul(cov, jnp.swapaxes(cov, -1, -2))
|
||||
cov_part = jnp.sum(jnp.square(diff), axis=(-2, -1))
|
||||
|
||||
return mean_part, cov_part
|
||||
|
||||
@ -83,7 +98,7 @@ class FrobeniusProjection(BaseProjection):
|
||||
|
||||
def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray,
|
||||
cov_part: jnp.ndarray) -> jnp.ndarray:
|
||||
batch_shape = cov.shape[:-2]
|
||||
batch_shape = cov.shape[:-2] if cov.ndim > 2 else cov.shape[:-1]
|
||||
cov_mask = cov_part > self.cov_bound
|
||||
|
||||
eta = jnp.ones(batch_shape, dtype=cov.dtype)
|
||||
@ -93,12 +108,10 @@ class FrobeniusProjection(BaseProjection):
|
||||
eta = jnp.maximum(-eta, eta)
|
||||
|
||||
if self.full_cov:
|
||||
new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
|
||||
new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / \
|
||||
(1. + eta + 1e-16)[..., None, None]
|
||||
proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov)
|
||||
return jnp.linalg.cholesky(proj_cov)
|
||||
else:
|
||||
# For diagonal case, simple broadcasting
|
||||
new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
|
||||
|
||||
proj_cov = jnp.where(cov_mask[..., None] if not self.full_cov else cov_mask[..., None, None],
|
||||
new_cov, cov)
|
||||
|
||||
return proj_cov
|
||||
return jnp.where(cov_mask[..., None], jnp.sqrt(new_cov), cov)
|
@ -13,6 +13,24 @@ from .exception_projection import makeExceptionProjection
|
||||
|
||||
MAX_EVAL = 1000
|
||||
|
||||
# Cache for projection operators
|
||||
_diag_proj_op = None
|
||||
_full_proj_op = None
|
||||
|
||||
def _get_diag_proj_op(batch_shape, dim):
|
||||
global _diag_proj_op
|
||||
if _diag_proj_op is None:
|
||||
_diag_proj_op = cpp_projection.BatchedDiagCovOnlyProjection(
|
||||
batch_shape, dim, max_eval=MAX_EVAL)
|
||||
return _diag_proj_op
|
||||
|
||||
def _get_full_proj_op(batch_shape, dim):
|
||||
global _full_proj_op
|
||||
if _full_proj_op is None:
|
||||
_full_proj_op = cpp_projection.BatchedCovOnlyProjection(
|
||||
batch_shape, dim, max_eval=MAX_EVAL)
|
||||
return _full_proj_op
|
||||
|
||||
class KLProjection(BaseProjection):
|
||||
"""KL divergence-based projection for Gaussian policies.
|
||||
|
||||
@ -166,6 +184,49 @@ class KLProjection(BaseProjection):
|
||||
if key not in policy_params or key not in old_policy_params:
|
||||
raise KeyError(f"Missing required key '{key}' in policy parameters")
|
||||
|
||||
@partial(jax.custom_vjp, nondiff_argnums=(2,))
|
||||
def project_diag_covariance(cov, old_cov, eps_cov):
|
||||
"""JAX wrapper for C++ diagonal covariance projection"""
|
||||
batch_shape = cov.shape[0]
|
||||
dim = cov.shape[-1]
|
||||
|
||||
cov_np = np.asarray(cov)
|
||||
old_cov_np = np.asarray(old_cov)
|
||||
eps = eps_cov * np.ones(batch_shape, dtype=old_cov_np.dtype)
|
||||
|
||||
p_op = _get_diag_proj_op(batch_shape, dim)
|
||||
|
||||
try:
|
||||
proj_cov = p_op.forward(eps, old_cov_np, cov_np)
|
||||
except:
|
||||
proj_cov = cov_np # Return input on failure
|
||||
|
||||
return jnp.array(proj_cov)
|
||||
|
||||
def project_diag_covariance_fwd(cov, old_cov, eps_cov):
|
||||
y = project_diag_covariance(cov, old_cov, eps_cov)
|
||||
return y, (cov, old_cov)
|
||||
|
||||
def project_diag_covariance_bwd(eps_cov, res, g):
|
||||
cov, old_cov = res
|
||||
|
||||
# Convert to numpy for C++ backward pass
|
||||
g_np = np.asarray(g)
|
||||
batch_shape = g_np.shape[0]
|
||||
dim = g_np.shape[-1]
|
||||
|
||||
# Get C++ projection operator
|
||||
p_op = _get_diag_proj_op(batch_shape, dim)
|
||||
|
||||
# Run C++ backward pass
|
||||
grad_cov = p_op.backward(g_np)
|
||||
|
||||
# Convert back to JAX array
|
||||
return jnp.array(grad_cov), None
|
||||
|
||||
# Register VJP rule for diagonal covariance projection
|
||||
project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd)
|
||||
|
||||
@partial(jax.custom_vjp, nondiff_argnums=(3,))
|
||||
def project_full_covariance(cov, chol, old_chol, eps_cov):
|
||||
"""JAX wrapper for C++ full covariance projection"""
|
||||
@ -179,7 +240,7 @@ def project_full_covariance(cov, chol, old_chol, eps_cov):
|
||||
eps = eps_cov * np.ones(batch_shape)
|
||||
|
||||
# Create C++ projection operator directly
|
||||
p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
|
||||
p_op = _get_full_proj_op(batch_shape, dim)
|
||||
|
||||
# Run C++ projection
|
||||
proj_cov = p_op.forward(eps, old_chol_np, chol_np, cov_np)
|
||||
@ -203,7 +264,7 @@ def project_full_covariance_bwd(eps_cov, res, g):
|
||||
dim = g_np.shape[-1]
|
||||
|
||||
# Get C++ projection operator
|
||||
p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
|
||||
p_op = _get_full_proj_op(batch_shape, dim)
|
||||
|
||||
# Run C++ backward pass
|
||||
grad_cov = p_op.backward(g_np)
|
||||
@ -214,48 +275,5 @@ def project_full_covariance_bwd(eps_cov, res, g):
|
||||
# Register VJP rule for full covariance projection
|
||||
project_full_covariance.defvjp(project_full_covariance_fwd, project_full_covariance_bwd)
|
||||
|
||||
@partial(jax.custom_vjp, nondiff_argnums=(2,))
|
||||
def project_diag_covariance(cov, old_cov, eps_cov):
|
||||
"""JAX wrapper for C++ diagonal covariance projection"""
|
||||
# Convert JAX arrays to numpy for C++ function
|
||||
cov_np = np.asarray(cov)
|
||||
old_cov_np = np.asarray(old_cov)
|
||||
batch_shape = cov_np.shape[0]
|
||||
dim = cov_np.shape[-1]
|
||||
eps = eps_cov * np.ones(batch_shape)
|
||||
|
||||
# Create C++ projection operator directly
|
||||
p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
|
||||
|
||||
# Run C++ projection
|
||||
proj_cov = p_op.forward(eps, old_cov_np, cov_np)
|
||||
|
||||
# Convert back to JAX array
|
||||
return jnp.array(proj_cov)
|
||||
|
||||
def project_diag_covariance_fwd(cov, old_cov, eps_cov):
|
||||
y = project_diag_covariance(cov, old_cov, eps_cov)
|
||||
return y, (cov, old_cov)
|
||||
|
||||
def project_diag_covariance_bwd(eps_cov, res, g):
|
||||
cov, old_cov = res
|
||||
|
||||
# Convert to numpy for C++ backward pass
|
||||
g_np = np.asarray(g)
|
||||
batch_shape = g_np.shape[0]
|
||||
dim = g_np.shape[-1]
|
||||
|
||||
# Get C++ projection operator
|
||||
p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
|
||||
|
||||
# Run C++ backward pass
|
||||
grad_cov = p_op.backward(g_np)
|
||||
|
||||
# Convert back to JAX array
|
||||
return jnp.array(grad_cov), None
|
||||
|
||||
# Register VJP rule for diagonal covariance projection
|
||||
project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd)
|
||||
|
||||
if not cpp_projection_available:
|
||||
KLProjection = makeExceptionProjection("ITPAL (C++ library) is not available. Please install the C++ library to use this projection.")
|
@ -1,6 +1,7 @@
|
||||
import jax.numpy as jnp
|
||||
from .base_projection import BaseProjection
|
||||
from typing import Dict, Tuple
|
||||
import jax
|
||||
|
||||
def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
|
||||
"""
|
||||
@ -22,7 +23,8 @@ class WassersteinProjection(BaseProjection):
|
||||
|
||||
def project(self, policy_params: Dict[str, jnp.ndarray],
|
||||
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
|
||||
assert not self.full_cov, "Wasserstein projection only supports diagonal covariance"
|
||||
if self.full_cov:
|
||||
print("Warning: Wasserstein projection with full covariance is wip, we recommend using diagonal covariance instead.")
|
||||
|
||||
mean = policy_params["loc"] # shape: (batch_size, dim)
|
||||
old_mean = old_policy_params["loc"]
|
||||
@ -73,27 +75,47 @@ class WassersteinProjection(BaseProjection):
|
||||
|
||||
def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray,
|
||||
scale_part: jnp.ndarray) -> jnp.ndarray:
|
||||
"""Project scale parameters (standard deviations for diagonal case)"""
|
||||
diff = scale - old_scale
|
||||
norm = jnp.sqrt(scale_part)
|
||||
"""Project scale parameters using multiplicative update.
|
||||
|
||||
if scale.ndim == 2: # Batched scale
|
||||
norm = norm[..., None]
|
||||
Args:
|
||||
scale: Current scale/sqrt of covariance
|
||||
old_scale: Previous scale/sqrt of covariance
|
||||
scale_part: W2 distance between scales
|
||||
|
||||
return jnp.where(norm > self.cov_bound,
|
||||
old_scale + diff * self.cov_bound / norm,
|
||||
scale)
|
||||
Returns:
|
||||
Projected scale that satisfies the trust region constraint
|
||||
"""
|
||||
# Check if projection needed
|
||||
cov_mask = scale_part > self.cov_bound
|
||||
|
||||
# Compute eta (multiplier for the update)
|
||||
batch_shape = scale.shape[:-2] if scale.ndim > 2 else scale.shape[:-1]
|
||||
eta = jnp.ones(batch_shape, dtype=scale.dtype)
|
||||
eta = jnp.where(cov_mask,
|
||||
jnp.sqrt(scale_part / self.cov_bound) - 1.,
|
||||
eta)
|
||||
eta = jnp.maximum(-eta, eta)
|
||||
|
||||
# Multiplicative update with correct broadcasting
|
||||
if scale.ndim > 2: # Full covariance case
|
||||
new_scale = (scale + jnp.einsum('...,...ij->...ij', eta, old_scale)) / \
|
||||
(1. + eta + 1e-16)[..., None, None]
|
||||
mask = cov_mask[..., None, None]
|
||||
else: # Diagonal case
|
||||
new_scale = (scale + eta[..., None] * old_scale) / \
|
||||
(1. + eta + 1e-16)[..., None]
|
||||
mask = cov_mask[..., None]
|
||||
|
||||
return jnp.where(mask, new_scale, scale)
|
||||
|
||||
def _gaussian_wasserstein(self, p, q):
|
||||
mean, scale = p
|
||||
mean_other, scale_other = q
|
||||
|
||||
# Keep batch dimension by only summing over feature dimension
|
||||
mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1) # -> (batch_size,)
|
||||
# Euclidean distance for mean part (we're in diagonal case)
|
||||
mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)
|
||||
|
||||
if scale.ndim == mean.ndim: # Batched scale
|
||||
# Standard W2 objective for covariance (diagonal case)
|
||||
cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale, axis=-1)
|
||||
else: # Non-contextual scale (single scale for all batches)
|
||||
cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale)
|
||||
|
||||
return mean_part, cov_part
|
@ -94,8 +94,8 @@ def test_diagonal_projection(ProjectionClass, needs_cpp, gaussian_params):
|
||||
assert jnp.all(jnp.isfinite(proj_params["scale"]))
|
||||
assert jnp.all(proj_params["scale"] > 0)
|
||||
|
||||
# Only check KL bounds for KL projection (and W2, which should approx hold as well)
|
||||
if ProjectionClass in [KLProjection, WassersteinProjection]:
|
||||
# Only check KL bounds for KL projection
|
||||
if ProjectionClass in [KLProjection]:
|
||||
kl = compute_gaussian_kl(proj_params, gaussian_params["old_params"])
|
||||
max_kl = (mean_bound + cov_bound) * 1.1 # Allow 10% margin
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user