Dominik Moritz Roth 2024-12-21 18:31:01 +01:00
parent 44eb3335ff
commit 8e991ae05b
3 changed files with 127 additions and 74 deletions

View File

@@ -1,6 +1,7 @@
 import jax.numpy as jnp
 from .base_projection import BaseProjection
 from typing import Dict
+import jax
 
 class FrobeniusProjection(BaseProjection):
     def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
@@ -64,12 +65,26 @@ class FrobeniusProjection(BaseProjection):
         old_mean, old_cov = q
 
         if self.scale_prec:
-            prec_old = jnp.linalg.inv(old_cov)
-            mean_part = jnp.sum(jnp.matmul(mean - old_mean, prec_old) * (mean - old_mean), axis=-1)
-            cov_part = jnp.sum(prec_old * cov, axis=(-2, -1)) - jnp.log(jnp.linalg.det(jnp.matmul(prec_old, cov))) - mean.shape[-1]
+            # Mahalanobis distance for the mean
+            diff = mean - old_mean
+            if old_cov.ndim == mean.ndim:  # diagonal case
+                mean_part = jnp.sum(jnp.square(diff / old_cov), axis=-1)
+            else:
+                solved = jax.scipy.linalg.solve_triangular(
+                    old_cov, diff[..., None], lower=True
+                )
+                mean_part = jnp.sum(jnp.square(solved.squeeze(-1)), axis=-1)
         else:
             mean_part = jnp.sum(jnp.square(mean - old_mean), axis=-1)
-            cov_part = jnp.sum(jnp.square(cov - old_cov), axis=(-2, -1))
 
+        # Frobenius norm for the covariance
+        if cov.ndim == mean.ndim:  # diagonal case
+            diff = old_cov - cov
+            cov_part = jnp.sum(jnp.square(diff), axis=-1)
+        else:
+            diff = jnp.matmul(old_cov, jnp.swapaxes(old_cov, -1, -2)) - \
+                   jnp.matmul(cov, jnp.swapaxes(cov, -1, -2))
+            cov_part = jnp.sum(jnp.square(diff), axis=(-2, -1))
 
         return mean_part, cov_part
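Note on the full-covariance branch above: the Frobenius term is computed on old_cov @ old_cov.T - cov @ cov.T, which equals the squared Frobenius distance between the reconstructed covariances if the inputs are Cholesky factors (an assumption not stated in the diff). A minimal sketch with hypothetical values:

import jax.numpy as jnp

chol_old = jnp.array([[1.0, 0.0], [0.5, 1.2]])   # hypothetical Cholesky factors
chol_new = jnp.array([[1.1, 0.0], [0.4, 1.0]])
diff = chol_old @ chol_old.T - chol_new @ chol_new.T
cov_part = jnp.sum(jnp.square(diff))              # squared Frobenius norm
assert jnp.isclose(cov_part, jnp.linalg.norm(diff) ** 2)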
@@ -83,7 +98,7 @@ class FrobeniusProjection(BaseProjection):
     def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray,
                         cov_part: jnp.ndarray) -> jnp.ndarray:
-        batch_shape = cov.shape[:-2]
+        batch_shape = cov.shape[:-2] if cov.ndim > 2 else cov.shape[:-1]
         cov_mask = cov_part > self.cov_bound
 
         eta = jnp.ones(batch_shape, dtype=cov.dtype)
@@ -93,12 +108,10 @@ class FrobeniusProjection(BaseProjection):
         eta = jnp.maximum(-eta, eta)
 
         if self.full_cov:
-            new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
+            new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / \
+                      (1. + eta + 1e-16)[..., None, None]
             proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov)
             return jnp.linalg.cholesky(proj_cov)
         else:
+            # For the diagonal case, simple broadcasting; take the sqrt of both
+            # branches so the return matches the Cholesky factor above
             new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
-            proj_cov = jnp.where(cov_mask[..., None] if not self.full_cov else cov_mask[..., None, None],
-                                 new_cov, cov)
-            return proj_cov
+            return jnp.sqrt(jnp.where(cov_mask[..., None], new_cov, cov))
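The projection step above interpolates between cov and old_cov with weight eta, returning a Cholesky factor in the full case and standard deviations in the diagonal case. A minimal sketch of the diagonal update, with hypothetical values:

import jax.numpy as jnp

var = jnp.array([2.0, 0.5])      # current diagonal covariance (violating the bound)
old_var = jnp.array([1.0, 1.0])  # previous diagonal covariance
eta = 3.0                        # larger eta pulls harder toward old_var
new_var = (var + eta * old_var) / (1.0 + eta + 1e-16)
# Convex combination: 1/(1+eta) * var + eta/(1+eta) * old_var
new_std = jnp.sqrt(new_var)      # returned in the sqrt parameterization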

View File

@@ -13,6 +13,24 @@ from .exception_projection import makeExceptionProjection
 
 MAX_EVAL = 1000
 
+# Cache for projection operators, keyed by (batch size, dim) so a cached
+# operator is never reused for inputs of a different shape
+_diag_proj_ops = {}
+_full_proj_ops = {}
+
+def _get_diag_proj_op(batch_shape, dim):
+    if (batch_shape, dim) not in _diag_proj_ops:
+        _diag_proj_ops[(batch_shape, dim)] = cpp_projection.BatchedDiagCovOnlyProjection(
+            batch_shape, dim, max_eval=MAX_EVAL)
+    return _diag_proj_ops[(batch_shape, dim)]
+
+def _get_full_proj_op(batch_shape, dim):
+    if (batch_shape, dim) not in _full_proj_ops:
+        _full_proj_ops[(batch_shape, dim)] = cpp_projection.BatchedCovOnlyProjection(
+            batch_shape, dim, max_eval=MAX_EVAL)
+    return _full_proj_ops[(batch_shape, dim)]
+
 class KLProjection(BaseProjection):
     """KL divergence-based projection for Gaussian policies.
@@ -166,6 +184,49 @@ class KLProjection(BaseProjection):
             if key not in policy_params or key not in old_policy_params:
                 raise KeyError(f"Missing required key '{key}' in policy parameters")
 
+@partial(jax.custom_vjp, nondiff_argnums=(2,))
+def project_diag_covariance(cov, old_cov, eps_cov):
+    """JAX wrapper for C++ diagonal covariance projection"""
+    batch_shape = cov.shape[0]
+    dim = cov.shape[-1]
+
+    cov_np = np.asarray(cov)
+    old_cov_np = np.asarray(old_cov)
+    eps = eps_cov * np.ones(batch_shape, dtype=old_cov_np.dtype)
+
+    p_op = _get_diag_proj_op(batch_shape, dim)
+    try:
+        proj_cov = p_op.forward(eps, old_cov_np, cov_np)
+    except Exception:
+        proj_cov = cov_np  # Fall back to the unprojected input on failure
+
+    return jnp.array(proj_cov)
+
+def project_diag_covariance_fwd(cov, old_cov, eps_cov):
+    y = project_diag_covariance(cov, old_cov, eps_cov)
+    return y, (cov, old_cov)
+
+def project_diag_covariance_bwd(eps_cov, res, g):
+    cov, old_cov = res
+    # Convert to numpy for C++ backward pass
+    g_np = np.asarray(g)
+    batch_shape = g_np.shape[0]
+    dim = g_np.shape[-1]
+
+    # Get the cached C++ projection operator
+    p_op = _get_diag_proj_op(batch_shape, dim)
+
+    # Run C++ backward pass
+    grad_cov = p_op.backward(g_np)
+
+    # Convert back to JAX array
+    return jnp.array(grad_cov), None
+
+# Register VJP rule for diagonal covariance projection
+project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd)
+
 @partial(jax.custom_vjp, nondiff_argnums=(3,))
 def project_full_covariance(cov, chol, old_chol, eps_cov):
     """JAX wrapper for C++ full covariance projection"""
@@ -179,7 +240,7 @@ def project_full_covariance(cov, chol, old_chol, eps_cov):
     eps = eps_cov * np.ones(batch_shape)
 
-    # Create C++ projection operator directly
-    p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
+    # Get C++ projection operator
+    p_op = _get_full_proj_op(batch_shape, dim)
 
     # Run C++ projection
     proj_cov = p_op.forward(eps, old_chol_np, chol_np, cov_np)
@@ -203,7 +264,7 @@ def project_full_covariance_bwd(eps_cov, res, g):
     dim = g_np.shape[-1]
 
     # Get C++ projection operator
-    p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
+    p_op = _get_full_proj_op(batch_shape, dim)
 
     # Run C++ backward pass
    grad_cov = p_op.backward(g_np)
@@ -214,48 +275,5 @@ def project_full_covariance_bwd(eps_cov, res, g):
 # Register VJP rule for full covariance projection
 project_full_covariance.defvjp(project_full_covariance_fwd, project_full_covariance_bwd)
 
-@partial(jax.custom_vjp, nondiff_argnums=(2,))
-def project_diag_covariance(cov, old_cov, eps_cov):
-    """JAX wrapper for C++ diagonal covariance projection"""
-    # Convert JAX arrays to numpy for C++ function
-    cov_np = np.asarray(cov)
-    old_cov_np = np.asarray(old_cov)
-    batch_shape = cov_np.shape[0]
-    dim = cov_np.shape[-1]
-
-    eps = eps_cov * np.ones(batch_shape)
-
-    # Create C++ projection operator directly
-    p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
-
-    # Run C++ projection
-    proj_cov = p_op.forward(eps, old_cov_np, cov_np)
-
-    # Convert back to JAX array
-    return jnp.array(proj_cov)
-
-def project_diag_covariance_fwd(cov, old_cov, eps_cov):
-    y = project_diag_covariance(cov, old_cov, eps_cov)
-    return y, (cov, old_cov)
-
-def project_diag_covariance_bwd(eps_cov, res, g):
-    cov, old_cov = res
-    # Convert to numpy for C++ backward pass
-    g_np = np.asarray(g)
-    batch_shape = g_np.shape[0]
-    dim = g_np.shape[-1]
-
-    # Get C++ projection operator
-    p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
-
-    # Run C++ backward pass
-    grad_cov = p_op.backward(g_np)
-
-    # Convert back to JAX array
-    return jnp.array(grad_cov), None
-
-# Register VJP rule for diagonal covariance projection
-project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd)
-
 if not cpp_projection_available:
     KLProjection = makeExceptionProjection("ITPAL (C++ library) is not available. Please install the C++ library to use this projection.")
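For reference, the custom_vjp pattern used throughout this file, reduced to a self-contained toy (the scaling op and all names here are hypothetical stand-ins for the C++ projection). With nondiff_argnums, the non-differentiable argument is passed first to the backward function, exactly as in the wrappers above:

from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scale_op(x, scale):
    return scale * x

def scale_op_fwd(x, scale):
    return scale_op(x, scale), None      # no residuals needed for this toy

def scale_op_bwd(scale, res, g):
    return (scale * g,)                  # one cotangent per differentiable arg

scale_op.defvjp(scale_op_fwd, scale_op_bwd)
assert jax.grad(lambda x: scale_op(x, 3.0))(2.0) == 3.0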

View File

@@ -1,6 +1,7 @@
 import jax.numpy as jnp
 from .base_projection import BaseProjection
 from typing import Dict, Tuple
+import jax
 
 def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
     """
@@ -22,7 +23,8 @@ class WassersteinProjection(BaseProjection):
     def project(self, policy_params: Dict[str, jnp.ndarray],
                 old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
-        assert not self.full_cov, "Wasserstein projection only supports diagonal covariance"
+        if self.full_cov:
+            print("Warning: Wasserstein projection with full covariance is work in progress; we recommend diagonal covariance instead.")
 
         mean = policy_params["loc"]  # shape: (batch_size, dim)
         old_mean = old_policy_params["loc"]
@@ -73,27 +75,47 @@ class WassersteinProjection(BaseProjection):
     def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray,
                           scale_part: jnp.ndarray) -> jnp.ndarray:
-        """Project scale parameters (standard deviations for diagonal case)"""
-        diff = scale - old_scale
-        norm = jnp.sqrt(scale_part)
-
-        if scale.ndim == 2:  # Batched scale
-            norm = norm[..., None]
-
-        return jnp.where(norm > self.cov_bound,
-                         old_scale + diff * self.cov_bound / norm,
-                         scale)
+        """Project scale parameters using a multiplicative update.
+
+        Args:
+            scale: Current scale (square root of the covariance)
+            old_scale: Previous scale (square root of the covariance)
+            scale_part: W2 distance between the scales
+
+        Returns:
+            Projected scale that satisfies the trust-region constraint
+        """
+        # Check whether a projection is needed
+        cov_mask = scale_part > self.cov_bound
+
+        # Compute eta (multiplier for the update)
+        batch_shape = scale.shape[:-2] if scale.ndim > 2 else scale.shape[:-1]
+        eta = jnp.ones(batch_shape, dtype=scale.dtype)
+        eta = jnp.where(cov_mask,
+                        jnp.sqrt(scale_part / self.cov_bound) - 1.,
+                        eta)
+        eta = jnp.maximum(-eta, eta)
+
+        # Multiplicative update with correct broadcasting
+        if scale.ndim > 2:  # Full covariance case
+            new_scale = (scale + jnp.einsum('...,...ij->...ij', eta, old_scale)) / \
+                        (1. + eta + 1e-16)[..., None, None]
+            mask = cov_mask[..., None, None]
+        else:  # Diagonal case
+            new_scale = (scale + eta[..., None] * old_scale) / \
+                        (1. + eta + 1e-16)[..., None]
+            mask = cov_mask[..., None]
+
+        return jnp.where(mask, new_scale, scale)
 
     def _gaussian_wasserstein(self, p, q):
         mean, scale = p
         mean_other, scale_other = q
 
-        # Keep batch dimension by only summing over feature dimension
-        mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)  # -> (batch_size,)
-
-        cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale, axis=-1)
+        # Euclidean distance for the mean part (diagonal case)
+        mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)
+
+        if scale.ndim == mean.ndim:  # Batched scale
+            # Standard W2 objective for the covariance (diagonal case)
+            cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale, axis=-1)
+        else:  # Non-contextual scale (one scale shared across the batch)
+            cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale)
 
         return mean_part, cov_part
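For diagonal Gaussians, the squared 2-Wasserstein distance computed above reduces to ||mean - mean_other||^2 + ||scale - scale_other||^2, since scale_other**2 + scale**2 - 2 * scale_other * scale = (scale - scale_other)**2 elementwise. A quick numeric check with hypothetical values:

import jax.numpy as jnp

mean, old_mean = jnp.array([0.0, 1.0]), jnp.array([0.5, 1.0])
std, old_std = jnp.array([1.0, 2.0]), jnp.array([1.5, 1.0])

mean_part = jnp.sum(jnp.square(mean - old_mean))
cov_part = jnp.sum(old_std**2 + std**2 - 2 * old_std * std)
assert jnp.isclose(cov_part, jnp.sum(jnp.square(std - old_std)))
w2_squared = mean_part + cov_part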