From 8e991ae05bf7aecab2cf14d371bff00aa93dc7b0 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 21 Dec 2024 18:31:01 +0100 Subject: [PATCH] Fixes --- itpal_jax/frobenius_projection.py | 41 +++++++---- itpal_jax/kl_projection.py | 108 ++++++++++++++++------------ itpal_jax/wasserstein_projection.py | 52 ++++++++++---- 3 files changed, 127 insertions(+), 74 deletions(-) diff --git a/itpal_jax/frobenius_projection.py b/itpal_jax/frobenius_projection.py index 2ec1d7e..2faadfb 100644 --- a/itpal_jax/frobenius_projection.py +++ b/itpal_jax/frobenius_projection.py @@ -1,6 +1,7 @@ import jax.numpy as jnp from .base_projection import BaseProjection from typing import Dict +import jax class FrobeniusProjection(BaseProjection): def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01, @@ -64,12 +65,26 @@ class FrobeniusProjection(BaseProjection): old_mean, old_cov = q if self.scale_prec: - prec_old = jnp.linalg.inv(old_cov) - mean_part = jnp.sum(jnp.matmul(mean - old_mean, prec_old) * (mean - old_mean), axis=-1) - cov_part = jnp.sum(prec_old * cov, axis=(-2, -1)) - jnp.log(jnp.linalg.det(jnp.matmul(prec_old, cov))) - mean.shape[-1] + # Mahalanobis distance for mean + diff = mean - old_mean + if old_cov.ndim == mean.ndim: # diagonal case + mean_part = jnp.sum(jnp.square(diff / old_cov), axis=-1) + else: + solved = jax.scipy.linalg.solve_triangular( + old_cov, diff[..., None], lower=True + ) + mean_part = jnp.sum(jnp.square(solved.squeeze(-1)), axis=-1) else: mean_part = jnp.sum(jnp.square(mean - old_mean), axis=-1) - cov_part = jnp.sum(jnp.square(cov - old_cov), axis=(-2, -1)) + + # Frobenius norm for covariance + if cov.ndim == mean.ndim: # diagonal case + diff = old_cov - cov + cov_part = jnp.sum(jnp.square(diff), axis=-1) + else: + diff = jnp.matmul(old_cov, jnp.swapaxes(old_cov, -1, -2)) - \ + jnp.matmul(cov, jnp.swapaxes(cov, -1, -2)) + cov_part = jnp.sum(jnp.square(diff), axis=(-2, -1)) return mean_part, cov_part @@ -83,22 +98,20 @@ class FrobeniusProjection(BaseProjection): def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray, cov_part: jnp.ndarray) -> jnp.ndarray: - batch_shape = cov.shape[:-2] + batch_shape = cov.shape[:-2] if cov.ndim > 2 else cov.shape[:-1] cov_mask = cov_part > self.cov_bound eta = jnp.ones(batch_shape, dtype=cov.dtype) eta = jnp.where(cov_mask, - jnp.sqrt(cov_part / self.cov_bound) - 1., - eta) + jnp.sqrt(cov_part / self.cov_bound) - 1., + eta) eta = jnp.maximum(-eta, eta) if self.full_cov: - new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None] + new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / \ + (1. + eta + 1e-16)[..., None, None] + proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov) + return jnp.linalg.cholesky(proj_cov) else: - # For diagonal case, simple broadcasting new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None] - - proj_cov = jnp.where(cov_mask[..., None] if not self.full_cov else cov_mask[..., None, None], - new_cov, cov) - - return proj_cov \ No newline at end of file + return jnp.where(cov_mask[..., None], jnp.sqrt(new_cov), cov) \ No newline at end of file diff --git a/itpal_jax/kl_projection.py b/itpal_jax/kl_projection.py index d75feef..4a66403 100644 --- a/itpal_jax/kl_projection.py +++ b/itpal_jax/kl_projection.py @@ -13,6 +13,24 @@ from .exception_projection import makeExceptionProjection MAX_EVAL = 1000 +# Cache for projection operators +_diag_proj_op = None +_full_proj_op = None + +def _get_diag_proj_op(batch_shape, dim): + global _diag_proj_op + if _diag_proj_op is None: + _diag_proj_op = cpp_projection.BatchedDiagCovOnlyProjection( + batch_shape, dim, max_eval=MAX_EVAL) + return _diag_proj_op + +def _get_full_proj_op(batch_shape, dim): + global _full_proj_op + if _full_proj_op is None: + _full_proj_op = cpp_projection.BatchedCovOnlyProjection( + batch_shape, dim, max_eval=MAX_EVAL) + return _full_proj_op + class KLProjection(BaseProjection): """KL divergence-based projection for Gaussian policies. @@ -166,6 +184,49 @@ class KLProjection(BaseProjection): if key not in policy_params or key not in old_policy_params: raise KeyError(f"Missing required key '{key}' in policy parameters") +@partial(jax.custom_vjp, nondiff_argnums=(2,)) +def project_diag_covariance(cov, old_cov, eps_cov): + """JAX wrapper for C++ diagonal covariance projection""" + batch_shape = cov.shape[0] + dim = cov.shape[-1] + + cov_np = np.asarray(cov) + old_cov_np = np.asarray(old_cov) + eps = eps_cov * np.ones(batch_shape, dtype=old_cov_np.dtype) + + p_op = _get_diag_proj_op(batch_shape, dim) + + try: + proj_cov = p_op.forward(eps, old_cov_np, cov_np) + except: + proj_cov = cov_np # Return input on failure + + return jnp.array(proj_cov) + +def project_diag_covariance_fwd(cov, old_cov, eps_cov): + y = project_diag_covariance(cov, old_cov, eps_cov) + return y, (cov, old_cov) + +def project_diag_covariance_bwd(eps_cov, res, g): + cov, old_cov = res + + # Convert to numpy for C++ backward pass + g_np = np.asarray(g) + batch_shape = g_np.shape[0] + dim = g_np.shape[-1] + + # Get C++ projection operator + p_op = _get_diag_proj_op(batch_shape, dim) + + # Run C++ backward pass + grad_cov = p_op.backward(g_np) + + # Convert back to JAX array + return jnp.array(grad_cov), None + +# Register VJP rule for diagonal covariance projection +project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd) + @partial(jax.custom_vjp, nondiff_argnums=(3,)) def project_full_covariance(cov, chol, old_chol, eps_cov): """JAX wrapper for C++ full covariance projection""" @@ -179,7 +240,7 @@ def project_full_covariance(cov, chol, old_chol, eps_cov): eps = eps_cov * np.ones(batch_shape) # Create C++ projection operator directly - p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL) + p_op = _get_full_proj_op(batch_shape, dim) # Run C++ projection proj_cov = p_op.forward(eps, old_chol_np, chol_np, cov_np) @@ -203,7 +264,7 @@ def project_full_covariance_bwd(eps_cov, res, g): dim = g_np.shape[-1] # Get C++ projection operator - p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL) + p_op = _get_full_proj_op(batch_shape, dim) # Run C++ backward pass grad_cov = p_op.backward(g_np) @@ -214,48 +275,5 @@ def project_full_covariance_bwd(eps_cov, res, g): # Register VJP rule for full covariance projection project_full_covariance.defvjp(project_full_covariance_fwd, project_full_covariance_bwd) -@partial(jax.custom_vjp, nondiff_argnums=(2,)) -def project_diag_covariance(cov, old_cov, eps_cov): - """JAX wrapper for C++ diagonal covariance projection""" - # Convert JAX arrays to numpy for C++ function - cov_np = np.asarray(cov) - old_cov_np = np.asarray(old_cov) - batch_shape = cov_np.shape[0] - dim = cov_np.shape[-1] - eps = eps_cov * np.ones(batch_shape) - - # Create C++ projection operator directly - p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL) - - # Run C++ projection - proj_cov = p_op.forward(eps, old_cov_np, cov_np) - - # Convert back to JAX array - return jnp.array(proj_cov) - -def project_diag_covariance_fwd(cov, old_cov, eps_cov): - y = project_diag_covariance(cov, old_cov, eps_cov) - return y, (cov, old_cov) - -def project_diag_covariance_bwd(eps_cov, res, g): - cov, old_cov = res - - # Convert to numpy for C++ backward pass - g_np = np.asarray(g) - batch_shape = g_np.shape[0] - dim = g_np.shape[-1] - - # Get C++ projection operator - p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL) - - # Run C++ backward pass - grad_cov = p_op.backward(g_np) - - # Convert back to JAX array - return jnp.array(grad_cov), None - -# Register VJP rule for diagonal covariance projection -project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd) - if not cpp_projection_available: KLProjection = makeExceptionProjection("ITPAL (C++ library) is not available. Please install the C++ library to use this projection.") \ No newline at end of file diff --git a/itpal_jax/wasserstein_projection.py b/itpal_jax/wasserstein_projection.py index ecc5237..ddf33d4 100644 --- a/itpal_jax/wasserstein_projection.py +++ b/itpal_jax/wasserstein_projection.py @@ -1,6 +1,7 @@ import jax.numpy as jnp from .base_projection import BaseProjection from typing import Dict, Tuple +import jax def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray: """ @@ -22,7 +23,8 @@ class WassersteinProjection(BaseProjection): def project(self, policy_params: Dict[str, jnp.ndarray], old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]: - assert not self.full_cov, "Wasserstein projection only supports diagonal covariance" + if self.full_cov: + print("Warning: Wasserstein projection with full covariance is wip, we recommend using diagonal covariance instead.") mean = policy_params["loc"] # shape: (batch_size, dim) old_mean = old_policy_params["loc"] @@ -73,27 +75,47 @@ class WassersteinProjection(BaseProjection): def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray, scale_part: jnp.ndarray) -> jnp.ndarray: - """Project scale parameters (standard deviations for diagonal case)""" - diff = scale - old_scale - norm = jnp.sqrt(scale_part) + """Project scale parameters using multiplicative update. - if scale.ndim == 2: # Batched scale - norm = norm[..., None] + Args: + scale: Current scale/sqrt of covariance + old_scale: Previous scale/sqrt of covariance + scale_part: W2 distance between scales - return jnp.where(norm > self.cov_bound, - old_scale + diff * self.cov_bound / norm, - scale) + Returns: + Projected scale that satisfies the trust region constraint + """ + # Check if projection needed + cov_mask = scale_part > self.cov_bound + + # Compute eta (multiplier for the update) + batch_shape = scale.shape[:-2] if scale.ndim > 2 else scale.shape[:-1] + eta = jnp.ones(batch_shape, dtype=scale.dtype) + eta = jnp.where(cov_mask, + jnp.sqrt(scale_part / self.cov_bound) - 1., + eta) + eta = jnp.maximum(-eta, eta) + + # Multiplicative update with correct broadcasting + if scale.ndim > 2: # Full covariance case + new_scale = (scale + jnp.einsum('...,...ij->...ij', eta, old_scale)) / \ + (1. + eta + 1e-16)[..., None, None] + mask = cov_mask[..., None, None] + else: # Diagonal case + new_scale = (scale + eta[..., None] * old_scale) / \ + (1. + eta + 1e-16)[..., None] + mask = cov_mask[..., None] + + return jnp.where(mask, new_scale, scale) def _gaussian_wasserstein(self, p, q): mean, scale = p mean_other, scale_other = q - # Keep batch dimension by only summing over feature dimension - mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1) # -> (batch_size,) + # Euclidean distance for mean part (we're in diagonal case) + mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1) - if scale.ndim == mean.ndim: # Batched scale - cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale, axis=-1) - else: # Non-contextual scale (single scale for all batches) - cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale) + # Standard W2 objective for covariance (diagonal case) + cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale, axis=-1) return mean_part, cov_part \ No newline at end of file