diff --git a/itpal_jax/kl_projection.py b/itpal_jax/kl_projection.py index 97d62a2..511f42b 100644 --- a/itpal_jax/kl_projection.py +++ b/itpal_jax/kl_projection.py @@ -160,26 +160,28 @@ class KLProjection(BaseProjection): old_cov = old_scale_or_tril ** 2 mask = cov_part > self.cov_bound + proj_scale_or_tril = scale_or_tril # Start with original scale - # Always compute both branches and use matrix operations to select - if self.full_cov: - proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound) - is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1))) - valid_mask = mask & ~is_invalid - - # Compute cholesky for all, let matrix ops handle selection - chol = jnp.linalg.cholesky(proj_cov) - mask_matrix = valid_mask[..., None, None].astype(scale_or_tril.dtype) - return mask_matrix * chol + (1 - mask_matrix) * scale_or_tril + if mask.any(): + if self.full_cov: + proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound) + is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1))) + proj_scale_or_tril = jnp.where(is_invalid[..., None, None], old_scale_or_tril, scale_or_tril) + mask = mask & ~is_invalid + chol = jnp.linalg.cholesky(proj_cov) + proj_scale_or_tril = jnp.where(mask[..., None, None], chol, proj_scale_or_tril) + else: + proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound) + is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) | + jnp.isinf(proj_cov.mean(axis=-1)) | + (proj_cov.min(axis=-1) < 0)) + proj_scale_or_tril = jnp.where(is_invalid[..., None], old_scale_or_tril, scale_or_tril) + mask = mask & ~is_invalid + proj_scale_or_tril = jnp.where(mask[..., None], jnp.sqrt(proj_cov), scale_or_tril) else: - proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound) - is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) | - jnp.isinf(proj_cov.mean(axis=-1)) | - (proj_cov.min(axis=-1) < 0)) - valid_mask = mask & ~is_invalid - - mask_matrix = valid_mask[..., None].astype(scale_or_tril.dtype) - return mask_matrix * jnp.sqrt(proj_cov) + (1 - mask_matrix) * scale_or_tril + proj_scale_or_tril = scale_or_tril + + return proj_scale_or_tril def _validate_inputs(self, policy_params, old_policy_params): """Validate input parameters have correct format."""