Compare commits
3 Commits
44eb3335ff
...
de2b9a10d6
Author | SHA1 | Date | |
---|---|---|---|
de2b9a10d6 | |||
9fb0014a99 | |||
8e991ae05b |
@ -65,9 +65,7 @@ pytest tests/test_projections.py
|
|||||||
*Note*: The test suite verifies:
|
*Note*: The test suite verifies:
|
||||||
|
|
||||||
1. All projections run without errors and maintain basic properties (shapes, positive definiteness)
|
1. All projections run without errors and maintain basic properties (shapes, positive definiteness)
|
||||||
2. KL bounds are actually (approximately) met for:
|
2. KL bounds are actually (approximately) met for true KL projection (both diagonal and full covariance)
|
||||||
- KL projection (both diagonal and full covariance)
|
|
||||||
- Wasserstein projection (diagonal covariance only)
|
|
||||||
3. Gradients can be computed through all projections:
|
3. Gradients can be computed through all projections:
|
||||||
- Both through projection operation and trust region loss
|
- Both through projection operation and trust region loss
|
||||||
- Gradients have correct shapes and are finite
|
- Gradients have correct shapes and are finite
|
@ -1,6 +1,7 @@
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from .base_projection import BaseProjection
|
from .base_projection import BaseProjection
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
import jax
|
||||||
|
|
||||||
class FrobeniusProjection(BaseProjection):
|
class FrobeniusProjection(BaseProjection):
|
||||||
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
|
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
|
||||||
@ -64,12 +65,26 @@ class FrobeniusProjection(BaseProjection):
|
|||||||
old_mean, old_cov = q
|
old_mean, old_cov = q
|
||||||
|
|
||||||
if self.scale_prec:
|
if self.scale_prec:
|
||||||
prec_old = jnp.linalg.inv(old_cov)
|
# Mahalanobis distance for mean
|
||||||
mean_part = jnp.sum(jnp.matmul(mean - old_mean, prec_old) * (mean - old_mean), axis=-1)
|
diff = mean - old_mean
|
||||||
cov_part = jnp.sum(prec_old * cov, axis=(-2, -1)) - jnp.log(jnp.linalg.det(jnp.matmul(prec_old, cov))) - mean.shape[-1]
|
if old_cov.ndim == mean.ndim: # diagonal case
|
||||||
|
mean_part = jnp.sum(jnp.square(diff / old_cov), axis=-1)
|
||||||
|
else:
|
||||||
|
solved = jax.scipy.linalg.solve_triangular(
|
||||||
|
old_cov, diff[..., None], lower=True
|
||||||
|
)
|
||||||
|
mean_part = jnp.sum(jnp.square(solved.squeeze(-1)), axis=-1)
|
||||||
else:
|
else:
|
||||||
mean_part = jnp.sum(jnp.square(mean - old_mean), axis=-1)
|
mean_part = jnp.sum(jnp.square(mean - old_mean), axis=-1)
|
||||||
cov_part = jnp.sum(jnp.square(cov - old_cov), axis=(-2, -1))
|
|
||||||
|
# Frobenius norm for covariance
|
||||||
|
if cov.ndim == mean.ndim: # diagonal case
|
||||||
|
diff = old_cov - cov
|
||||||
|
cov_part = jnp.sum(jnp.square(diff), axis=-1)
|
||||||
|
else:
|
||||||
|
diff = jnp.matmul(old_cov, jnp.swapaxes(old_cov, -1, -2)) - \
|
||||||
|
jnp.matmul(cov, jnp.swapaxes(cov, -1, -2))
|
||||||
|
cov_part = jnp.sum(jnp.square(diff), axis=(-2, -1))
|
||||||
|
|
||||||
return mean_part, cov_part
|
return mean_part, cov_part
|
||||||
|
|
||||||
@ -83,22 +98,20 @@ class FrobeniusProjection(BaseProjection):
|
|||||||
|
|
||||||
def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray,
|
def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray,
|
||||||
cov_part: jnp.ndarray) -> jnp.ndarray:
|
cov_part: jnp.ndarray) -> jnp.ndarray:
|
||||||
batch_shape = cov.shape[:-2]
|
batch_shape = cov.shape[:-2] if cov.ndim > 2 else cov.shape[:-1]
|
||||||
cov_mask = cov_part > self.cov_bound
|
cov_mask = cov_part > self.cov_bound
|
||||||
|
|
||||||
eta = jnp.ones(batch_shape, dtype=cov.dtype)
|
eta = jnp.ones(batch_shape, dtype=cov.dtype)
|
||||||
eta = jnp.where(cov_mask,
|
eta = jnp.where(cov_mask,
|
||||||
jnp.sqrt(cov_part / self.cov_bound) - 1.,
|
jnp.sqrt(cov_part / self.cov_bound) - 1.,
|
||||||
eta)
|
eta)
|
||||||
eta = jnp.maximum(-eta, eta)
|
eta = jnp.maximum(-eta, eta)
|
||||||
|
|
||||||
if self.full_cov:
|
if self.full_cov:
|
||||||
new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
|
new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / \
|
||||||
|
(1. + eta + 1e-16)[..., None, None]
|
||||||
|
proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov)
|
||||||
|
return jnp.linalg.cholesky(proj_cov)
|
||||||
else:
|
else:
|
||||||
# For diagonal case, simple broadcasting
|
|
||||||
new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
|
new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
|
||||||
|
return jnp.where(cov_mask[..., None], jnp.sqrt(new_cov), cov)
|
||||||
proj_cov = jnp.where(cov_mask[..., None] if not self.full_cov else cov_mask[..., None, None],
|
|
||||||
new_cov, cov)
|
|
||||||
|
|
||||||
return proj_cov
|
|
@ -13,6 +13,24 @@ from .exception_projection import makeExceptionProjection
|
|||||||
|
|
||||||
MAX_EVAL = 1000
|
MAX_EVAL = 1000
|
||||||
|
|
||||||
|
# Cache for projection operators
|
||||||
|
_diag_proj_op = None
|
||||||
|
_full_proj_op = None
|
||||||
|
|
||||||
|
def _get_diag_proj_op(batch_shape, dim):
|
||||||
|
global _diag_proj_op
|
||||||
|
if _diag_proj_op is None:
|
||||||
|
_diag_proj_op = cpp_projection.BatchedDiagCovOnlyProjection(
|
||||||
|
batch_shape, dim, max_eval=MAX_EVAL)
|
||||||
|
return _diag_proj_op
|
||||||
|
|
||||||
|
def _get_full_proj_op(batch_shape, dim):
|
||||||
|
global _full_proj_op
|
||||||
|
if _full_proj_op is None:
|
||||||
|
_full_proj_op = cpp_projection.BatchedCovOnlyProjection(
|
||||||
|
batch_shape, dim, max_eval=MAX_EVAL)
|
||||||
|
return _full_proj_op
|
||||||
|
|
||||||
class KLProjection(BaseProjection):
|
class KLProjection(BaseProjection):
|
||||||
"""KL divergence-based projection for Gaussian policies.
|
"""KL divergence-based projection for Gaussian policies.
|
||||||
|
|
||||||
@ -166,6 +184,49 @@ class KLProjection(BaseProjection):
|
|||||||
if key not in policy_params or key not in old_policy_params:
|
if key not in policy_params or key not in old_policy_params:
|
||||||
raise KeyError(f"Missing required key '{key}' in policy parameters")
|
raise KeyError(f"Missing required key '{key}' in policy parameters")
|
||||||
|
|
||||||
|
@partial(jax.custom_vjp, nondiff_argnums=(2,))
|
||||||
|
def project_diag_covariance(cov, old_cov, eps_cov):
|
||||||
|
"""JAX wrapper for C++ diagonal covariance projection"""
|
||||||
|
batch_shape = cov.shape[0]
|
||||||
|
dim = cov.shape[-1]
|
||||||
|
|
||||||
|
cov_np = np.asarray(cov)
|
||||||
|
old_cov_np = np.asarray(old_cov)
|
||||||
|
eps = eps_cov * np.ones(batch_shape, dtype=old_cov_np.dtype)
|
||||||
|
|
||||||
|
p_op = _get_diag_proj_op(batch_shape, dim)
|
||||||
|
|
||||||
|
try:
|
||||||
|
proj_cov = p_op.forward(eps, old_cov_np, cov_np)
|
||||||
|
except:
|
||||||
|
proj_cov = cov_np # Return input on failure
|
||||||
|
|
||||||
|
return jnp.array(proj_cov)
|
||||||
|
|
||||||
|
def project_diag_covariance_fwd(cov, old_cov, eps_cov):
|
||||||
|
y = project_diag_covariance(cov, old_cov, eps_cov)
|
||||||
|
return y, (cov, old_cov)
|
||||||
|
|
||||||
|
def project_diag_covariance_bwd(eps_cov, res, g):
|
||||||
|
cov, old_cov = res
|
||||||
|
|
||||||
|
# Convert to numpy for C++ backward pass
|
||||||
|
g_np = np.asarray(g)
|
||||||
|
batch_shape = g_np.shape[0]
|
||||||
|
dim = g_np.shape[-1]
|
||||||
|
|
||||||
|
# Get C++ projection operator
|
||||||
|
p_op = _get_diag_proj_op(batch_shape, dim)
|
||||||
|
|
||||||
|
# Run C++ backward pass
|
||||||
|
grad_cov = p_op.backward(g_np)
|
||||||
|
|
||||||
|
# Convert back to JAX array
|
||||||
|
return jnp.array(grad_cov), None
|
||||||
|
|
||||||
|
# Register VJP rule for diagonal covariance projection
|
||||||
|
project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd)
|
||||||
|
|
||||||
@partial(jax.custom_vjp, nondiff_argnums=(3,))
|
@partial(jax.custom_vjp, nondiff_argnums=(3,))
|
||||||
def project_full_covariance(cov, chol, old_chol, eps_cov):
|
def project_full_covariance(cov, chol, old_chol, eps_cov):
|
||||||
"""JAX wrapper for C++ full covariance projection"""
|
"""JAX wrapper for C++ full covariance projection"""
|
||||||
@ -179,7 +240,7 @@ def project_full_covariance(cov, chol, old_chol, eps_cov):
|
|||||||
eps = eps_cov * np.ones(batch_shape)
|
eps = eps_cov * np.ones(batch_shape)
|
||||||
|
|
||||||
# Create C++ projection operator directly
|
# Create C++ projection operator directly
|
||||||
p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
|
p_op = _get_full_proj_op(batch_shape, dim)
|
||||||
|
|
||||||
# Run C++ projection
|
# Run C++ projection
|
||||||
proj_cov = p_op.forward(eps, old_chol_np, chol_np, cov_np)
|
proj_cov = p_op.forward(eps, old_chol_np, chol_np, cov_np)
|
||||||
@ -203,7 +264,7 @@ def project_full_covariance_bwd(eps_cov, res, g):
|
|||||||
dim = g_np.shape[-1]
|
dim = g_np.shape[-1]
|
||||||
|
|
||||||
# Get C++ projection operator
|
# Get C++ projection operator
|
||||||
p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
|
p_op = _get_full_proj_op(batch_shape, dim)
|
||||||
|
|
||||||
# Run C++ backward pass
|
# Run C++ backward pass
|
||||||
grad_cov = p_op.backward(g_np)
|
grad_cov = p_op.backward(g_np)
|
||||||
@ -214,48 +275,5 @@ def project_full_covariance_bwd(eps_cov, res, g):
|
|||||||
# Register VJP rule for full covariance projection
|
# Register VJP rule for full covariance projection
|
||||||
project_full_covariance.defvjp(project_full_covariance_fwd, project_full_covariance_bwd)
|
project_full_covariance.defvjp(project_full_covariance_fwd, project_full_covariance_bwd)
|
||||||
|
|
||||||
@partial(jax.custom_vjp, nondiff_argnums=(2,))
|
|
||||||
def project_diag_covariance(cov, old_cov, eps_cov):
|
|
||||||
"""JAX wrapper for C++ diagonal covariance projection"""
|
|
||||||
# Convert JAX arrays to numpy for C++ function
|
|
||||||
cov_np = np.asarray(cov)
|
|
||||||
old_cov_np = np.asarray(old_cov)
|
|
||||||
batch_shape = cov_np.shape[0]
|
|
||||||
dim = cov_np.shape[-1]
|
|
||||||
eps = eps_cov * np.ones(batch_shape)
|
|
||||||
|
|
||||||
# Create C++ projection operator directly
|
|
||||||
p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
|
|
||||||
|
|
||||||
# Run C++ projection
|
|
||||||
proj_cov = p_op.forward(eps, old_cov_np, cov_np)
|
|
||||||
|
|
||||||
# Convert back to JAX array
|
|
||||||
return jnp.array(proj_cov)
|
|
||||||
|
|
||||||
def project_diag_covariance_fwd(cov, old_cov, eps_cov):
|
|
||||||
y = project_diag_covariance(cov, old_cov, eps_cov)
|
|
||||||
return y, (cov, old_cov)
|
|
||||||
|
|
||||||
def project_diag_covariance_bwd(eps_cov, res, g):
|
|
||||||
cov, old_cov = res
|
|
||||||
|
|
||||||
# Convert to numpy for C++ backward pass
|
|
||||||
g_np = np.asarray(g)
|
|
||||||
batch_shape = g_np.shape[0]
|
|
||||||
dim = g_np.shape[-1]
|
|
||||||
|
|
||||||
# Get C++ projection operator
|
|
||||||
p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
|
|
||||||
|
|
||||||
# Run C++ backward pass
|
|
||||||
grad_cov = p_op.backward(g_np)
|
|
||||||
|
|
||||||
# Convert back to JAX array
|
|
||||||
return jnp.array(grad_cov), None
|
|
||||||
|
|
||||||
# Register VJP rule for diagonal covariance projection
|
|
||||||
project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd)
|
|
||||||
|
|
||||||
if not cpp_projection_available:
|
if not cpp_projection_available:
|
||||||
KLProjection = makeExceptionProjection("ITPAL (C++ library) is not available. Please install the C++ library to use this projection.")
|
KLProjection = makeExceptionProjection("ITPAL (C++ library) is not available. Please install the C++ library to use this projection.")
|
@ -1,6 +1,7 @@
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from .base_projection import BaseProjection
|
from .base_projection import BaseProjection
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
import jax
|
||||||
|
|
||||||
def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
|
def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
|
||||||
"""
|
"""
|
||||||
@ -22,7 +23,8 @@ class WassersteinProjection(BaseProjection):
|
|||||||
|
|
||||||
def project(self, policy_params: Dict[str, jnp.ndarray],
|
def project(self, policy_params: Dict[str, jnp.ndarray],
|
||||||
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
|
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
|
||||||
assert not self.full_cov, "Wasserstein projection only supports diagonal covariance"
|
if self.full_cov:
|
||||||
|
print("Warning: Wasserstein projection with full covariance is wip, we recommend using diagonal covariance instead.")
|
||||||
|
|
||||||
mean = policy_params["loc"] # shape: (batch_size, dim)
|
mean = policy_params["loc"] # shape: (batch_size, dim)
|
||||||
old_mean = old_policy_params["loc"]
|
old_mean = old_policy_params["loc"]
|
||||||
@ -73,27 +75,47 @@ class WassersteinProjection(BaseProjection):
|
|||||||
|
|
||||||
def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray,
|
def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray,
|
||||||
scale_part: jnp.ndarray) -> jnp.ndarray:
|
scale_part: jnp.ndarray) -> jnp.ndarray:
|
||||||
"""Project scale parameters (standard deviations for diagonal case)"""
|
"""Project scale parameters using multiplicative update.
|
||||||
diff = scale - old_scale
|
|
||||||
norm = jnp.sqrt(scale_part)
|
|
||||||
|
|
||||||
if scale.ndim == 2: # Batched scale
|
Args:
|
||||||
norm = norm[..., None]
|
scale: Current scale/sqrt of covariance
|
||||||
|
old_scale: Previous scale/sqrt of covariance
|
||||||
|
scale_part: W2 distance between scales
|
||||||
|
|
||||||
return jnp.where(norm > self.cov_bound,
|
Returns:
|
||||||
old_scale + diff * self.cov_bound / norm,
|
Projected scale that satisfies the trust region constraint
|
||||||
scale)
|
"""
|
||||||
|
# Check if projection needed
|
||||||
|
cov_mask = scale_part > self.cov_bound
|
||||||
|
|
||||||
|
# Compute eta (multiplier for the update)
|
||||||
|
batch_shape = scale.shape[:-2] if scale.ndim > 2 else scale.shape[:-1]
|
||||||
|
eta = jnp.ones(batch_shape, dtype=scale.dtype)
|
||||||
|
eta = jnp.where(cov_mask,
|
||||||
|
jnp.sqrt(scale_part / self.cov_bound) - 1.,
|
||||||
|
eta)
|
||||||
|
eta = jnp.maximum(-eta, eta)
|
||||||
|
|
||||||
|
# Multiplicative update with correct broadcasting
|
||||||
|
if scale.ndim > 2: # Full covariance case
|
||||||
|
new_scale = (scale + jnp.einsum('...,...ij->...ij', eta, old_scale)) / \
|
||||||
|
(1. + eta + 1e-16)[..., None, None]
|
||||||
|
mask = cov_mask[..., None, None]
|
||||||
|
else: # Diagonal case
|
||||||
|
new_scale = (scale + eta[..., None] * old_scale) / \
|
||||||
|
(1. + eta + 1e-16)[..., None]
|
||||||
|
mask = cov_mask[..., None]
|
||||||
|
|
||||||
|
return jnp.where(mask, new_scale, scale)
|
||||||
|
|
||||||
def _gaussian_wasserstein(self, p, q):
|
def _gaussian_wasserstein(self, p, q):
|
||||||
mean, scale = p
|
mean, scale = p
|
||||||
mean_other, scale_other = q
|
mean_other, scale_other = q
|
||||||
|
|
||||||
# Keep batch dimension by only summing over feature dimension
|
# Euclidean distance for mean part (we're in diagonal case)
|
||||||
mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1) # -> (batch_size,)
|
mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)
|
||||||
|
|
||||||
if scale.ndim == mean.ndim: # Batched scale
|
# Standard W2 objective for covariance (diagonal case)
|
||||||
cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale, axis=-1)
|
cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale, axis=-1)
|
||||||
else: # Non-contextual scale (single scale for all batches)
|
|
||||||
cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale)
|
|
||||||
|
|
||||||
return mean_part, cov_part
|
return mean_part, cov_part
|
@ -94,8 +94,8 @@ def test_diagonal_projection(ProjectionClass, needs_cpp, gaussian_params):
|
|||||||
assert jnp.all(jnp.isfinite(proj_params["scale"]))
|
assert jnp.all(jnp.isfinite(proj_params["scale"]))
|
||||||
assert jnp.all(proj_params["scale"] > 0)
|
assert jnp.all(proj_params["scale"] > 0)
|
||||||
|
|
||||||
# Only check KL bounds for KL projection (and W2, which should approx hold as well)
|
# Only check KL bounds for KL projection
|
||||||
if ProjectionClass in [KLProjection, WassersteinProjection]:
|
if ProjectionClass in [KLProjection]:
|
||||||
kl = compute_gaussian_kl(proj_params, gaussian_params["old_params"])
|
kl = compute_gaussian_kl(proj_params, gaussian_params["old_params"])
|
||||||
max_kl = (mean_bound + cov_bound) * 1.1 # Allow 10% margin
|
max_kl = (mean_bound + cov_bound) * 1.1 # Allow 10% margin
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user