From 4d6ed9b3ace01d28f52189c33c23595621d24c6d Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 7 Jan 2025 16:54:20 +0100 Subject: [PATCH] Better jit (bool mask via matmul) --- itpal_jax/base_projection.py | 18 +++++--------- itpal_jax/frobenius_projection.py | 6 +++-- itpal_jax/kl_projection.py | 38 ++++++++++++++--------------- itpal_jax/wasserstein_projection.py | 10 ++++---- 4 files changed, 33 insertions(+), 39 deletions(-) diff --git a/itpal_jax/base_projection.py b/itpal_jax/base_projection.py index a72ae16..30fe615 100644 --- a/itpal_jax/base_projection.py +++ b/itpal_jax/base_projection.py @@ -25,6 +25,7 @@ class BaseProjection(ABC): """Compute trust region loss between original and projected parameters.""" raise NotImplementedError + @partial(jax.jit, static_argnames=('self')) def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray, mean_part: jnp.ndarray) -> jnp.ndarray: """Project mean based on the Mahalanobis objective and trust region. @@ -38,21 +39,14 @@ class BaseProjection(ABC): Projected mean that satisfies the trust region """ mask = mean_part > self.mean_bound - - # If nothing needs to be projected, skip computation - if not jnp.any(mask): - return mean - - # Compute projection factor - omega = jnp.ones(mean_part.shape, dtype=mean.dtype) - omega = jnp.where(mask, - jnp.sqrt(mean_part / self.mean_bound) - 1., - omega) + omega = jnp.ones_like(mean_part) + omega = jnp.where(mask, jnp.sqrt(mean_part / self.mean_bound) - 1., omega) omega = jnp.maximum(-omega, omega)[..., None] - # Project mean + # Use matrix operations instead of boolean indexing m = (mean + omega * old_mean) / (1. + omega + 1e-16) - return jnp.where(mask[..., None], m, mean) + mask_matrix = mask[..., None].astype(mean.dtype) + return mask_matrix * m + (1 - mask_matrix) * mean def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray, cov_part: jnp.ndarray) -> jnp.ndarray: diff --git a/itpal_jax/frobenius_projection.py b/itpal_jax/frobenius_projection.py index e18820d..75615aa 100644 --- a/itpal_jax/frobenius_projection.py +++ b/itpal_jax/frobenius_projection.py @@ -105,8 +105,10 @@ class FrobeniusProjection(BaseProjection): if self.full_cov: new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / \ (1. + eta + 1e-16)[..., None, None] - proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov) + mask_matrix = cov_mask[..., None, None].astype(cov.dtype) + proj_cov = mask_matrix * new_cov + (1 - mask_matrix) * cov return jnp.linalg.cholesky(proj_cov) else: new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None] - return jnp.where(cov_mask[..., None], jnp.sqrt(new_cov), cov) \ No newline at end of file + mask_matrix = cov_mask[..., None].astype(cov.dtype) + return mask_matrix * jnp.sqrt(new_cov) + (1 - mask_matrix) * cov \ No newline at end of file diff --git a/itpal_jax/kl_projection.py b/itpal_jax/kl_projection.py index 511f42b..97d62a2 100644 --- a/itpal_jax/kl_projection.py +++ b/itpal_jax/kl_projection.py @@ -160,28 +160,26 @@ class KLProjection(BaseProjection): old_cov = old_scale_or_tril ** 2 mask = cov_part > self.cov_bound - proj_scale_or_tril = scale_or_tril # Start with original scale - if mask.any(): - if self.full_cov: - proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound) - is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1))) - proj_scale_or_tril = jnp.where(is_invalid[..., None, None], old_scale_or_tril, scale_or_tril) - mask = mask & ~is_invalid - chol = jnp.linalg.cholesky(proj_cov) - proj_scale_or_tril = jnp.where(mask[..., None, None], chol, proj_scale_or_tril) - else: - proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound) - is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) | - jnp.isinf(proj_cov.mean(axis=-1)) | - (proj_cov.min(axis=-1) < 0)) - proj_scale_or_tril = jnp.where(is_invalid[..., None], old_scale_or_tril, scale_or_tril) - mask = mask & ~is_invalid - proj_scale_or_tril = jnp.where(mask[..., None], jnp.sqrt(proj_cov), scale_or_tril) + # Always compute both branches and use matrix operations to select + if self.full_cov: + proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound) + is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1))) + valid_mask = mask & ~is_invalid + + # Compute cholesky for all, let matrix ops handle selection + chol = jnp.linalg.cholesky(proj_cov) + mask_matrix = valid_mask[..., None, None].astype(scale_or_tril.dtype) + return mask_matrix * chol + (1 - mask_matrix) * scale_or_tril else: - proj_scale_or_tril = scale_or_tril - - return proj_scale_or_tril + proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound) + is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) | + jnp.isinf(proj_cov.mean(axis=-1)) | + (proj_cov.min(axis=-1) < 0)) + valid_mask = mask & ~is_invalid + + mask_matrix = valid_mask[..., None].astype(scale_or_tril.dtype) + return mask_matrix * jnp.sqrt(proj_cov) + (1 - mask_matrix) * scale_or_tril def _validate_inputs(self, policy_params, old_policy_params): """Validate input parameters have correct format.""" diff --git a/itpal_jax/wasserstein_projection.py b/itpal_jax/wasserstein_projection.py index 56a1c09..0b4b083 100644 --- a/itpal_jax/wasserstein_projection.py +++ b/itpal_jax/wasserstein_projection.py @@ -91,17 +91,17 @@ class WassersteinProjection(BaseProjection): eta) eta = jnp.maximum(-eta, eta) - # Multiplicative update with correct broadcasting + # Multiplicative update with matrix operations if scale.ndim > 2: # Full covariance case new_scale = (scale + jnp.einsum('...,...ij->...ij', eta, old_scale)) / \ (1. + eta + 1e-16)[..., None, None] - mask = cov_mask[..., None, None] + mask_matrix = cov_mask[..., None, None].astype(scale.dtype) + return mask_matrix * new_scale + (1 - mask_matrix) * scale else: # Diagonal case new_scale = (scale + eta[..., None] * old_scale) / \ (1. + eta + 1e-16)[..., None] - mask = cov_mask[..., None] - - return jnp.where(mask, new_scale, scale) + mask_matrix = cov_mask[..., None].astype(scale.dtype) + return mask_matrix * new_scale + (1 - mask_matrix) * scale @staticmethod @jax.jit