Better jit (bool mask via matmul)

This commit is contained in:
Dominik Moritz Roth 2025-01-07 16:54:20 +01:00
parent 7fca6186d5
commit 4d6ed9b3ac
4 changed files with 33 additions and 39 deletions

View File

@ -25,6 +25,7 @@ class BaseProjection(ABC):
"""Compute trust region loss between original and projected parameters.""" """Compute trust region loss between original and projected parameters."""
raise NotImplementedError raise NotImplementedError
@partial(jax.jit, static_argnames=('self'))
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray, def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
mean_part: jnp.ndarray) -> jnp.ndarray: mean_part: jnp.ndarray) -> jnp.ndarray:
"""Project mean based on the Mahalanobis objective and trust region. """Project mean based on the Mahalanobis objective and trust region.
@ -38,21 +39,14 @@ class BaseProjection(ABC):
Projected mean that satisfies the trust region Projected mean that satisfies the trust region
""" """
mask = mean_part > self.mean_bound mask = mean_part > self.mean_bound
omega = jnp.ones_like(mean_part)
# If nothing needs to be projected, skip computation omega = jnp.where(mask, jnp.sqrt(mean_part / self.mean_bound) - 1., omega)
if not jnp.any(mask):
return mean
# Compute projection factor
omega = jnp.ones(mean_part.shape, dtype=mean.dtype)
omega = jnp.where(mask,
jnp.sqrt(mean_part / self.mean_bound) - 1.,
omega)
omega = jnp.maximum(-omega, omega)[..., None] omega = jnp.maximum(-omega, omega)[..., None]
# Project mean # Use matrix operations instead of boolean indexing
m = (mean + omega * old_mean) / (1. + omega + 1e-16) m = (mean + omega * old_mean) / (1. + omega + 1e-16)
return jnp.where(mask[..., None], m, mean) mask_matrix = mask[..., None].astype(mean.dtype)
return mask_matrix * m + (1 - mask_matrix) * mean
def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray, def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray,
cov_part: jnp.ndarray) -> jnp.ndarray: cov_part: jnp.ndarray) -> jnp.ndarray:

View File

@ -105,8 +105,10 @@ class FrobeniusProjection(BaseProjection):
if self.full_cov: if self.full_cov:
new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / \ new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / \
(1. + eta + 1e-16)[..., None, None] (1. + eta + 1e-16)[..., None, None]
proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov) mask_matrix = cov_mask[..., None, None].astype(cov.dtype)
proj_cov = mask_matrix * new_cov + (1 - mask_matrix) * cov
return jnp.linalg.cholesky(proj_cov) return jnp.linalg.cholesky(proj_cov)
else: else:
new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None] new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
return jnp.where(cov_mask[..., None], jnp.sqrt(new_cov), cov) mask_matrix = cov_mask[..., None].astype(cov.dtype)
return mask_matrix * jnp.sqrt(new_cov) + (1 - mask_matrix) * cov

View File

@ -160,28 +160,26 @@ class KLProjection(BaseProjection):
old_cov = old_scale_or_tril ** 2 old_cov = old_scale_or_tril ** 2
mask = cov_part > self.cov_bound mask = cov_part > self.cov_bound
proj_scale_or_tril = scale_or_tril # Start with original scale
if mask.any(): # Always compute both branches and use matrix operations to select
if self.full_cov: if self.full_cov:
proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound) proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound)
is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1))) is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1)))
proj_scale_or_tril = jnp.where(is_invalid[..., None, None], old_scale_or_tril, scale_or_tril) valid_mask = mask & ~is_invalid
mask = mask & ~is_invalid
chol = jnp.linalg.cholesky(proj_cov) # Compute cholesky for all, let matrix ops handle selection
proj_scale_or_tril = jnp.where(mask[..., None, None], chol, proj_scale_or_tril) chol = jnp.linalg.cholesky(proj_cov)
else: mask_matrix = valid_mask[..., None, None].astype(scale_or_tril.dtype)
proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound) return mask_matrix * chol + (1 - mask_matrix) * scale_or_tril
is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) |
jnp.isinf(proj_cov.mean(axis=-1)) |
(proj_cov.min(axis=-1) < 0))
proj_scale_or_tril = jnp.where(is_invalid[..., None], old_scale_or_tril, scale_or_tril)
mask = mask & ~is_invalid
proj_scale_or_tril = jnp.where(mask[..., None], jnp.sqrt(proj_cov), scale_or_tril)
else: else:
proj_scale_or_tril = scale_or_tril proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound)
is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) |
jnp.isinf(proj_cov.mean(axis=-1)) |
(proj_cov.min(axis=-1) < 0))
valid_mask = mask & ~is_invalid
return proj_scale_or_tril mask_matrix = valid_mask[..., None].astype(scale_or_tril.dtype)
return mask_matrix * jnp.sqrt(proj_cov) + (1 - mask_matrix) * scale_or_tril
def _validate_inputs(self, policy_params, old_policy_params): def _validate_inputs(self, policy_params, old_policy_params):
"""Validate input parameters have correct format.""" """Validate input parameters have correct format."""

View File

@ -91,17 +91,17 @@ class WassersteinProjection(BaseProjection):
eta) eta)
eta = jnp.maximum(-eta, eta) eta = jnp.maximum(-eta, eta)
# Multiplicative update with correct broadcasting # Multiplicative update with matrix operations
if scale.ndim > 2: # Full covariance case if scale.ndim > 2: # Full covariance case
new_scale = (scale + jnp.einsum('...,...ij->...ij', eta, old_scale)) / \ new_scale = (scale + jnp.einsum('...,...ij->...ij', eta, old_scale)) / \
(1. + eta + 1e-16)[..., None, None] (1. + eta + 1e-16)[..., None, None]
mask = cov_mask[..., None, None] mask_matrix = cov_mask[..., None, None].astype(scale.dtype)
return mask_matrix * new_scale + (1 - mask_matrix) * scale
else: # Diagonal case else: # Diagonal case
new_scale = (scale + eta[..., None] * old_scale) / \ new_scale = (scale + eta[..., None] * old_scale) / \
(1. + eta + 1e-16)[..., None] (1. + eta + 1e-16)[..., None]
mask = cov_mask[..., None] mask_matrix = cov_mask[..., None].astype(scale.dtype)
return mask_matrix * new_scale + (1 - mask_matrix) * scale
return jnp.where(mask, new_scale, scale)
@staticmethod @staticmethod
@jax.jit @jax.jit