Better jit (bool mask via matmul)
This commit is contained in:
parent
7fca6186d5
commit
4d6ed9b3ac
@ -25,6 +25,7 @@ class BaseProjection(ABC):
|
|||||||
"""Compute trust region loss between original and projected parameters."""
|
"""Compute trust region loss between original and projected parameters."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnames=('self'))
|
||||||
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
|
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
|
||||||
mean_part: jnp.ndarray) -> jnp.ndarray:
|
mean_part: jnp.ndarray) -> jnp.ndarray:
|
||||||
"""Project mean based on the Mahalanobis objective and trust region.
|
"""Project mean based on the Mahalanobis objective and trust region.
|
||||||
@ -38,21 +39,14 @@ class BaseProjection(ABC):
|
|||||||
Projected mean that satisfies the trust region
|
Projected mean that satisfies the trust region
|
||||||
"""
|
"""
|
||||||
mask = mean_part > self.mean_bound
|
mask = mean_part > self.mean_bound
|
||||||
|
omega = jnp.ones_like(mean_part)
|
||||||
# If nothing needs to be projected, skip computation
|
omega = jnp.where(mask, jnp.sqrt(mean_part / self.mean_bound) - 1., omega)
|
||||||
if not jnp.any(mask):
|
|
||||||
return mean
|
|
||||||
|
|
||||||
# Compute projection factor
|
|
||||||
omega = jnp.ones(mean_part.shape, dtype=mean.dtype)
|
|
||||||
omega = jnp.where(mask,
|
|
||||||
jnp.sqrt(mean_part / self.mean_bound) - 1.,
|
|
||||||
omega)
|
|
||||||
omega = jnp.maximum(-omega, omega)[..., None]
|
omega = jnp.maximum(-omega, omega)[..., None]
|
||||||
|
|
||||||
# Project mean
|
# Use matrix operations instead of boolean indexing
|
||||||
m = (mean + omega * old_mean) / (1. + omega + 1e-16)
|
m = (mean + omega * old_mean) / (1. + omega + 1e-16)
|
||||||
return jnp.where(mask[..., None], m, mean)
|
mask_matrix = mask[..., None].astype(mean.dtype)
|
||||||
|
return mask_matrix * m + (1 - mask_matrix) * mean
|
||||||
|
|
||||||
def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray,
|
def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray,
|
||||||
cov_part: jnp.ndarray) -> jnp.ndarray:
|
cov_part: jnp.ndarray) -> jnp.ndarray:
|
||||||
|
@ -105,8 +105,10 @@ class FrobeniusProjection(BaseProjection):
|
|||||||
if self.full_cov:
|
if self.full_cov:
|
||||||
new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / \
|
new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / \
|
||||||
(1. + eta + 1e-16)[..., None, None]
|
(1. + eta + 1e-16)[..., None, None]
|
||||||
proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov)
|
mask_matrix = cov_mask[..., None, None].astype(cov.dtype)
|
||||||
|
proj_cov = mask_matrix * new_cov + (1 - mask_matrix) * cov
|
||||||
return jnp.linalg.cholesky(proj_cov)
|
return jnp.linalg.cholesky(proj_cov)
|
||||||
else:
|
else:
|
||||||
new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
|
new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
|
||||||
return jnp.where(cov_mask[..., None], jnp.sqrt(new_cov), cov)
|
mask_matrix = cov_mask[..., None].astype(cov.dtype)
|
||||||
|
return mask_matrix * jnp.sqrt(new_cov) + (1 - mask_matrix) * cov
|
@ -160,28 +160,26 @@ class KLProjection(BaseProjection):
|
|||||||
old_cov = old_scale_or_tril ** 2
|
old_cov = old_scale_or_tril ** 2
|
||||||
|
|
||||||
mask = cov_part > self.cov_bound
|
mask = cov_part > self.cov_bound
|
||||||
proj_scale_or_tril = scale_or_tril # Start with original scale
|
|
||||||
|
|
||||||
if mask.any():
|
# Always compute both branches and use matrix operations to select
|
||||||
if self.full_cov:
|
if self.full_cov:
|
||||||
proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound)
|
proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound)
|
||||||
is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1)))
|
is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1)))
|
||||||
proj_scale_or_tril = jnp.where(is_invalid[..., None, None], old_scale_or_tril, scale_or_tril)
|
valid_mask = mask & ~is_invalid
|
||||||
mask = mask & ~is_invalid
|
|
||||||
|
# Compute cholesky for all, let matrix ops handle selection
|
||||||
chol = jnp.linalg.cholesky(proj_cov)
|
chol = jnp.linalg.cholesky(proj_cov)
|
||||||
proj_scale_or_tril = jnp.where(mask[..., None, None], chol, proj_scale_or_tril)
|
mask_matrix = valid_mask[..., None, None].astype(scale_or_tril.dtype)
|
||||||
|
return mask_matrix * chol + (1 - mask_matrix) * scale_or_tril
|
||||||
else:
|
else:
|
||||||
proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound)
|
proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound)
|
||||||
is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) |
|
is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) |
|
||||||
jnp.isinf(proj_cov.mean(axis=-1)) |
|
jnp.isinf(proj_cov.mean(axis=-1)) |
|
||||||
(proj_cov.min(axis=-1) < 0))
|
(proj_cov.min(axis=-1) < 0))
|
||||||
proj_scale_or_tril = jnp.where(is_invalid[..., None], old_scale_or_tril, scale_or_tril)
|
valid_mask = mask & ~is_invalid
|
||||||
mask = mask & ~is_invalid
|
|
||||||
proj_scale_or_tril = jnp.where(mask[..., None], jnp.sqrt(proj_cov), scale_or_tril)
|
|
||||||
else:
|
|
||||||
proj_scale_or_tril = scale_or_tril
|
|
||||||
|
|
||||||
return proj_scale_or_tril
|
mask_matrix = valid_mask[..., None].astype(scale_or_tril.dtype)
|
||||||
|
return mask_matrix * jnp.sqrt(proj_cov) + (1 - mask_matrix) * scale_or_tril
|
||||||
|
|
||||||
def _validate_inputs(self, policy_params, old_policy_params):
|
def _validate_inputs(self, policy_params, old_policy_params):
|
||||||
"""Validate input parameters have correct format."""
|
"""Validate input parameters have correct format."""
|
||||||
|
@ -91,17 +91,17 @@ class WassersteinProjection(BaseProjection):
|
|||||||
eta)
|
eta)
|
||||||
eta = jnp.maximum(-eta, eta)
|
eta = jnp.maximum(-eta, eta)
|
||||||
|
|
||||||
# Multiplicative update with correct broadcasting
|
# Multiplicative update with matrix operations
|
||||||
if scale.ndim > 2: # Full covariance case
|
if scale.ndim > 2: # Full covariance case
|
||||||
new_scale = (scale + jnp.einsum('...,...ij->...ij', eta, old_scale)) / \
|
new_scale = (scale + jnp.einsum('...,...ij->...ij', eta, old_scale)) / \
|
||||||
(1. + eta + 1e-16)[..., None, None]
|
(1. + eta + 1e-16)[..., None, None]
|
||||||
mask = cov_mask[..., None, None]
|
mask_matrix = cov_mask[..., None, None].astype(scale.dtype)
|
||||||
|
return mask_matrix * new_scale + (1 - mask_matrix) * scale
|
||||||
else: # Diagonal case
|
else: # Diagonal case
|
||||||
new_scale = (scale + eta[..., None] * old_scale) / \
|
new_scale = (scale + eta[..., None] * old_scale) / \
|
||||||
(1. + eta + 1e-16)[..., None]
|
(1. + eta + 1e-16)[..., None]
|
||||||
mask = cov_mask[..., None]
|
mask_matrix = cov_mask[..., None].astype(scale.dtype)
|
||||||
|
return mask_matrix * new_scale + (1 - mask_matrix) * scale
|
||||||
return jnp.where(mask, new_scale, scale)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@jax.jit
|
@jax.jit
|
||||||
|
Loading…
Reference in New Issue
Block a user