commit 9f85217a47413421986537b363fec021fb4cc0ed Author: Dominik Roth Date: Wed Dec 11 18:33:40 2024 +0100 to git we go diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/itpal_jax/__init__.py b/itpal_jax/__init__.py new file mode 100644 index 0000000..8282c2a --- /dev/null +++ b/itpal_jax/__init__.py @@ -0,0 +1,24 @@ +""" +JAX implementations of various policy projection methods. + +Available projections: + - BaseProjection: Abstract base class for projections + - IdentityProjection: Simple identity mapping + - FrobeniusProjection: Frobenius norm-based projection + - WassersteinProjection: Wasserstein distance-based projection + - KLProjection: KL divergence-based projection (requires C++ backend) +""" + +from .base_projection import BaseProjection +from .identity_projection import IdentityProjection +from .frobenius_projection import FrobeniusProjection +from .wasserstein_projection import WassersteinProjection +from .kl_projection import KLProjection + +__all__ = [ + 'BaseProjection', + 'IdentityProjection', + 'FrobeniusProjection', + 'WassersteinProjection', + 'KLProjection', +] diff --git a/itpal_jax/base_projection.py b/itpal_jax/base_projection.py new file mode 100644 index 0000000..3ec32ea --- /dev/null +++ b/itpal_jax/base_projection.py @@ -0,0 +1,35 @@ +from abc import ABC, abstractmethod +from typing import Dict +import jax.numpy as jnp + +class BaseProjection(ABC): + def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01, + cov_bound: float = 0.01, contextual_std: bool = True, full_cov: bool = False): + self.trust_region_coeff = trust_region_coeff + self.mean_bound = mean_bound + self.cov_bound = cov_bound + self.full_cov = full_cov + self.contextual_std = contextual_std + + @abstractmethod + def project(self, policy_params: Dict[str, jnp.ndarray], + old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]: + pass + + @abstractmethod + def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], + proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: + pass + + def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray: + if not self.full_cov: + return jnp.diag(params["scale"] ** 2) + else: + scale_tril = params["scale_tril"] + return jnp.matmul(scale_tril, jnp.swapaxes(scale_tril, -1, -2)) + + def _calc_scale_or_scale_tril(self, cov: jnp.ndarray) -> jnp.ndarray: + if not self.full_cov: + return jnp.sqrt(jnp.diagonal(cov, axis1=-2, axis2=-1)) + else: + return jnp.linalg.cholesky(cov) \ No newline at end of file diff --git a/itpal_jax/exception_projection.py b/itpal_jax/exception_projection.py new file mode 100644 index 0000000..487b40e --- /dev/null +++ b/itpal_jax/exception_projection.py @@ -0,0 +1,20 @@ +from .base_projection import BaseProjection +from functools import partial +import jax.numpy as jnp +from typing import Dict + +class ExceptionProjection(BaseProjection): + def __init__(self, msg, *args, **kwargs): + self.msg = msg + raise Exception(msg) + + def project(self, policy_params: Dict[str, jnp.ndarray], + old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]: + pass + + def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], + proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: + pass + +def makeExceptionProjection(msg: str): + return partial(ExceptionProjection, msg) \ No newline at end of file diff --git a/itpal_jax/frobenius_projection.py b/itpal_jax/frobenius_projection.py new file mode 100644 index 0000000..3cd3cf7 --- /dev/null +++ b/itpal_jax/frobenius_projection.py @@ -0,0 +1,78 @@ +import jax.numpy as jnp +from .base_projection import BaseProjection +from typing import Dict + +class FrobeniusProjection(BaseProjection): + def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01, + cov_bound: float = 0.01, scale_prec: bool = False, + contextual_std: bool = True, full_cov: bool = False): + super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, + cov_bound=cov_bound, contextual_std=contextual_std, full_cov=full_cov) + self.scale_prec = scale_prec + + def project(self, policy_params: Dict[str, jnp.ndarray], + old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]: + mean = policy_params["loc"] + old_mean = old_policy_params["loc"] + + cov = self._calc_covariance(policy_params) + old_cov = self._calc_covariance(old_policy_params) + + mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov)) + + proj_mean = self._mean_projection(mean, old_mean, mean_part) + proj_cov = self._cov_projection(cov, old_cov, cov_part) + + scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov) + return {"loc": proj_mean, "scale": scale_or_scale_tril} + + def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], + proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: + mean = policy_params["loc"] + proj_mean = proj_policy_params["loc"] + + cov = self._calc_covariance(policy_params) + proj_cov = self._calc_covariance(proj_policy_params) + + mean_diff = jnp.sum(jnp.square(mean - proj_mean), axis=-1) + cov_diff = jnp.sum(jnp.square(cov - proj_cov), axis=(-2, -1)) + + return (mean_diff + cov_diff).mean() * self.trust_region_coeff + + def _gaussian_frobenius(self, p, q): + mean, cov = p + old_mean, old_cov = q + + if self.scale_prec: + prec_old = jnp.linalg.inv(old_cov) + mean_part = jnp.sum(jnp.matmul(mean - old_mean, prec_old) * (mean - old_mean), axis=-1) + cov_part = jnp.sum(prec_old * cov, axis=(-2, -1)) - jnp.log(jnp.linalg.det(jnp.matmul(prec_old, cov))) - mean.shape[-1] + else: + mean_part = jnp.sum(jnp.square(mean - old_mean), axis=-1) + cov_part = jnp.sum(jnp.square(cov - old_cov), axis=(-2, -1)) + + return mean_part, cov_part + + def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray, + mean_part: jnp.ndarray) -> jnp.ndarray: + diff = mean - old_mean + norm = jnp.sqrt(mean_part) + return jnp.where(norm > self.mean_bound, + old_mean + diff * self.mean_bound / norm[..., None], + mean) + + def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray, + cov_part: jnp.ndarray) -> jnp.ndarray: + batch_shape = cov.shape[:-2] + cov_mask = cov_part > self.cov_bound + + eta = jnp.ones(batch_shape, dtype=cov.dtype) + eta = jnp.where(cov_mask, + jnp.sqrt(cov_part / self.cov_bound) - 1., + eta) + eta = jnp.maximum(-eta, eta) + + new_cov = (cov + jnp.einsum('i,ijk->ijk', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None] + proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov) + + return proj_cov \ No newline at end of file diff --git a/itpal_jax/identity_projection.py b/itpal_jax/identity_projection.py new file mode 100644 index 0000000..a27d202 --- /dev/null +++ b/itpal_jax/identity_projection.py @@ -0,0 +1,17 @@ +import jax.numpy as jnp +from .base_projection import BaseProjection +from typing import Dict + +class IdentityProjection(BaseProjection): + def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01, + cov_bound: float = 0.01, contextual_std: bool = True, full_cov: bool = False): + super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, + cov_bound=cov_bound, contextual_std=contextual_std, full_cov=full_cov) + + def project(self, policy_params: Dict[str, jnp.ndarray], + old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]: + return policy_params + + def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], + proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: + return jnp.array(0.0) \ No newline at end of file diff --git a/itpal_jax/kl_projection.py b/itpal_jax/kl_projection.py new file mode 100644 index 0000000..17dd332 --- /dev/null +++ b/itpal_jax/kl_projection.py @@ -0,0 +1,244 @@ +try: + import cpp_projection + cpp_projection_available = True +except ImportError: + cpp_projection_available = False +import jax +import jax.numpy as jnp +import numpy as np +from functools import partial +from typing import Dict, Tuple, Any +from .base_projection import BaseProjection +from .exception_projection import makeExceptionProjection + +MAX_EVAL = 1000 + +class KLProjection(BaseProjection): + """KL divergence-based projection for Gaussian policies. + + This class implements KL divergence projection using JAX, with C++ backend + for efficient projection operations. It supports both diagonal and full + covariance matrices. + + Args: + trust_region_coeff (float): Coefficient for trust region loss + mean_bound (float): Bound for mean projection + cov_bound (float): Bound for covariance projection + contextual_std (bool): Whether to use contextual standard deviations + """ + def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01, + cov_bound: float = 0.01, contextual_std: bool = True, full_cov: bool = False): + super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, + cov_bound=cov_bound, contextual_std=contextual_std, full_cov=full_cov) + + def project(self, policy_params: Dict[str, jnp.ndarray], + old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]: + self._validate_inputs(policy_params, old_policy_params) + mean, scale_or_tril = policy_params["loc"], policy_params["scale"] + old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params["scale"] + + mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril), + (old_mean, old_scale_or_tril)) + + if not self.contextual_std: + scale_or_tril = scale_or_tril[:1] + old_scale_or_tril = old_scale_or_tril[:1] + cov_part = cov_part[:1] + + proj_mean = self._mean_projection(mean, old_mean, mean_part) + proj_scale_or_tril = self._cov_projection(scale_or_tril, old_scale_or_tril, cov_part) + + if not self.contextual_std: + proj_scale_or_tril = jnp.broadcast_to( + proj_scale_or_tril, + (mean.shape[0],) + proj_scale_or_tril.shape[1:] + ) + + return {"loc": proj_mean, "scale": proj_scale_or_tril} + + def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], + proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: + mean, scale_or_tril = policy_params["loc"], policy_params["scale"] + proj_mean, proj_scale_or_tril = proj_policy_params["loc"], proj_policy_params["scale"] + kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril))) + return jnp.mean(kl) * self.trust_region_coeff + + def _gaussian_kl(self, p: Tuple[jnp.ndarray, jnp.ndarray], + q: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]: + mean, scale_or_tril = p + mean_other, scale_or_tril_other = q + k = mean.shape[-1] + + maha_part = 0.5 * self._maha(mean, mean_other, scale_or_tril_other) + + det_term = self._log_determinant(scale_or_tril) + det_term_other = self._log_determinant(scale_or_tril_other) + + if self.full_cov: + trace_part = self._batched_trace_square( + jax.scipy.linalg.solve_triangular( + scale_or_tril_other, scale_or_tril, lower=True + ) + ) + else: + trace_part = jnp.sum((scale_or_tril / scale_or_tril_other) ** 2, axis=-1) + + cov_part = 0.5 * (trace_part - k + det_term_other - det_term) + + return maha_part, cov_part + + def _maha(self, x: jnp.ndarray, y: jnp.ndarray, scale_or_tril: jnp.ndarray) -> jnp.ndarray: + diff = x - y + if self.full_cov: + solved = jax.scipy.linalg.solve_triangular( + scale_or_tril, diff[..., None], lower=True + ) + return jnp.sum(jnp.square(solved.squeeze(-1)), axis=-1) + else: + return jnp.sum(jnp.square(diff / scale_or_tril), axis=-1) + + def _log_determinant(self, scale_or_tril: jnp.ndarray) -> jnp.ndarray: + if self.full_cov: + return 2 * jnp.sum(jnp.log(jnp.diagonal(scale_or_tril, axis1=-2, axis2=-1)), axis=-1) + else: + return 2 * jnp.sum(jnp.log(scale_or_tril), axis=-1) + + def _batched_trace_square(self, x: jnp.ndarray) -> jnp.ndarray: + return jnp.sum(x ** 2, axis=(-2, -1)) + + def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray, + mean_part: jnp.ndarray) -> jnp.ndarray: + return old_mean + (mean - old_mean) * jnp.sqrt( + self.mean_bound / (mean_part + 1e-8) + )[..., None] + + def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray, cov_part: jnp.ndarray) -> jnp.ndarray: + if self.full_cov: + cov = jnp.matmul(scale_or_tril, jnp.swapaxes(scale_or_tril, -1, -2)) + old_cov = jnp.matmul(old_scale_or_tril, jnp.swapaxes(old_scale_or_tril, -1, -2)) + else: + cov = scale_or_tril ** 2 + old_cov = old_scale_or_tril ** 2 + + mask = cov_part > self.cov_bound + proj_scale_or_tril = jnp.zeros_like(scale_or_tril) + proj_scale_or_tril = jnp.where(~mask, scale_or_tril, proj_scale_or_tril) + + if mask.any(): + if self.full_cov: + proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound) + is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1))) & mask + proj_scale_or_tril = jnp.where(is_invalid, old_scale_or_tril, proj_scale_or_tril) + mask = mask & ~is_invalid + chol = jnp.linalg.cholesky(proj_cov) + proj_scale_or_tril = jnp.where(mask[..., None, None], chol, proj_scale_or_tril) + else: + proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound) + is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) | + jnp.isinf(proj_cov.mean(axis=-1)) | + (proj_cov.min(axis=-1) < 0)) & mask + proj_scale_or_tril = jnp.where(is_invalid, old_scale_or_tril, proj_scale_or_tril) + mask = mask & ~is_invalid + proj_scale_or_tril = jnp.where(mask[..., None], jnp.sqrt(proj_cov), proj_scale_or_tril) + + return proj_scale_or_tril + + def _validate_inputs(self, policy_params, old_policy_params): + required_keys = ["loc", "scale"] + for key in required_keys: + if key not in policy_params or key not in old_policy_params: + raise KeyError(f"Missing required key '{key}' in policy parameters") + +@partial(jax.custom_vjp, nondiff_argnums=(3,)) +def project_full_covariance(cov, chol, old_chol, eps_cov): + """JAX wrapper for C++ full covariance projection""" + try: + # Convert JAX arrays to numpy for C++ function + cov_np = np.asarray(cov) + chol_np = np.asarray(chol) + old_chol_np = np.asarray(old_chol) + batch_shape = cov_np.shape[0] + dim = cov_np.shape[-1] + eps = eps_cov * np.ones(batch_shape) + + # Create C++ projection operator directly + p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL) + + # Run C++ projection + proj_cov = p_op.forward(eps, old_chol_np, chol_np, cov_np) + + # Convert back to JAX array + return jnp.array(proj_cov) + except Exception as e: + print(f"Full covariance projection failed: {e}") + return old_chol # Return old values on failure + +def project_full_covariance_fwd(cov, chol, old_chol, eps_cov): + y = project_full_covariance(cov, chol, old_chol, eps_cov) + return y, (cov, chol, old_chol) + +def project_full_covariance_bwd(eps_cov, res, g): + cov, chol, old_chol = res + + # Convert to numpy for C++ backward pass + g_np = np.asarray(g) + batch_shape = g_np.shape[0] + dim = g_np.shape[-1] + + # Get C++ projection operator + p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL) + + # Run C++ backward pass + grad_cov = p_op.backward(g_np) + + # Convert back to JAX array + return jnp.array(grad_cov), None, None + +# Register VJP rule for full covariance projection +project_full_covariance.defvjp(project_full_covariance_fwd, project_full_covariance_bwd) + +@partial(jax.custom_vjp, nondiff_argnums=(2,)) +def project_diag_covariance(cov, old_cov, eps_cov): + """JAX wrapper for C++ diagonal covariance projection""" + # Convert JAX arrays to numpy for C++ function + cov_np = np.asarray(cov) + old_cov_np = np.asarray(old_cov) + batch_shape = cov_np.shape[0] + dim = cov_np.shape[-1] + eps = eps_cov * np.ones(batch_shape) + + # Create C++ projection operator directly + p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL) + + # Run C++ projection + proj_cov = p_op.forward(eps, old_cov_np, cov_np) + + # Convert back to JAX array + return jnp.array(proj_cov) + +def project_diag_covariance_fwd(cov, old_cov, eps_cov): + y = project_diag_covariance(cov, old_cov, eps_cov) + return y, (cov, old_cov) + +def project_diag_covariance_bwd(eps_cov, res, g): + cov, old_cov = res + + # Convert to numpy for C++ backward pass + g_np = np.asarray(g) + batch_shape = g_np.shape[0] + dim = g_np.shape[-1] + + # Get C++ projection operator + p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL) + + # Run C++ backward pass + grad_cov = p_op.backward(g_np) + + # Convert back to JAX array + return jnp.array(grad_cov), None + +# Register VJP rule for diagonal covariance projection +project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd) + +if not cpp_projection_available: + KLProjection = makeExceptionProjection("ITPAL (C++ library) is not available. Please install the C++ library to use this projection.") \ No newline at end of file diff --git a/itpal_jax/wasserstein_projection.py b/itpal_jax/wasserstein_projection.py new file mode 100644 index 0000000..aa224cc --- /dev/null +++ b/itpal_jax/wasserstein_projection.py @@ -0,0 +1,108 @@ +import jax.numpy as jnp +from .base_projection import BaseProjection +from typing import Dict, Tuple + +def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray: + """ + 'Converts' scale_tril to scale_sqrt. + + For Wasserstein distance, we need the matrix square root, not the Cholesky decomposition. + But since both are lower triangular, we can treat the Cholesky decomposition as if it were the matrix square root. + """ + return scale_tril + +def gaussian_wasserstein_commutative(p: Tuple[jnp.ndarray, jnp.ndarray], + q: Tuple[jnp.ndarray, jnp.ndarray], + scale_prec: bool = False) -> Tuple[jnp.ndarray, jnp.ndarray]: + mean, scale_or_sqrt = p + mean_other, scale_or_sqrt_other = q + + mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1) + + if scale_or_sqrt.ndim == mean.ndim: # Diagonal case + cov = scale_or_sqrt ** 2 + cov_other = scale_or_sqrt_other ** 2 + if scale_prec: + identity = jnp.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype) + sqrt_inv_other = 1 / scale_or_sqrt_other + c = sqrt_inv_other ** 2 * cov + cov_part = jnp.sum(identity + c - 2 * sqrt_inv_other * scale_or_sqrt, axis=-1) + else: + cov_part = jnp.sum(cov_other + cov - 2 * scale_or_sqrt_other * scale_or_sqrt, axis=-1) + else: # Full covariance case + # Note: scale_or_sqrt is treated as the matrix square root, not Cholesky decomposition + cov = jnp.matmul(scale_or_sqrt, jnp.swapaxes(scale_or_sqrt, -1, -2)) + cov_other = jnp.matmul(scale_or_sqrt_other, jnp.swapaxes(scale_or_sqrt_other, -1, -2)) + if scale_prec: + identity = jnp.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype) + sqrt_inv_other = jnp.linalg.solve(scale_or_sqrt_other, identity) + c = sqrt_inv_other @ cov @ jnp.swapaxes(sqrt_inv_other, -1, -2) + cov_part = jnp.trace(identity + c - 2 * sqrt_inv_other @ scale_or_sqrt) + else: + cov_part = jnp.trace(cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt) + + return mean_part, cov_part + +class WassersteinProjection(BaseProjection): + def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01, + cov_bound: float = 0.01, scale_prec: bool = False, + contextual_std: bool = True, full_cov: bool = False): + assert not full_cov, "Full covariance is not supported for Wasserstein projection" + super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, + cov_bound=cov_bound, contextual_std=contextual_std, full_cov=False) + self.scale_prec = scale_prec + + def project(self, policy_params: Dict[str, jnp.ndarray], + old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]: + mean = policy_params["loc"] + old_mean = old_policy_params["loc"] + scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"]) + old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params["scale"]) + + mean_part, cov_part = gaussian_wasserstein_commutative( + (mean, scale_or_sqrt), + (old_mean, old_scale_or_sqrt), + self.scale_prec + ) + + proj_mean = self._mean_projection(mean, old_mean, mean_part) + proj_scale_or_sqrt = self._cov_projection(scale_or_sqrt, old_scale_or_sqrt, cov_part) + + return {"loc": proj_mean, "scale": proj_scale_or_sqrt} + + def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], + proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: + mean = policy_params["loc"] + proj_mean = proj_policy_params["loc"] + scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"]) + proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params["scale"]) + mean_part, cov_part = gaussian_wasserstein_commutative( + (mean, scale_or_sqrt), + (proj_mean, proj_scale_or_sqrt), + self.scale_prec + ) + w2 = mean_part + cov_part + return w2.mean() * self.trust_region_coeff + + def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray, + mean_part: jnp.ndarray) -> jnp.ndarray: + diff = mean - old_mean + norm = jnp.sqrt(mean_part) + return jnp.where(norm > self.mean_bound, + old_mean + diff * self.mean_bound / norm[..., None], + mean) + + def _cov_projection(self, scale_or_sqrt: jnp.ndarray, old_scale_or_sqrt: jnp.ndarray, + cov_part: jnp.ndarray) -> jnp.ndarray: + if scale_or_sqrt.ndim == old_scale_or_sqrt.ndim == 2: # Diagonal case + diff = scale_or_sqrt - old_scale_or_sqrt + norm = jnp.sqrt(cov_part) + return jnp.where(norm > self.cov_bound, + old_scale_or_sqrt + diff * self.cov_bound / norm[..., None], + scale_or_sqrt) + else: # Full covariance case + diff = scale_or_sqrt - old_scale_or_sqrt + norm = jnp.linalg.norm(diff, axis=(-2, -1), keepdims=True) + return jnp.where(norm > self.cov_bound, + old_scale_or_sqrt + diff * self.cov_bound / norm, + scale_or_sqrt) \ No newline at end of file