revert kl, cxant kit compile c-binding

This commit is contained in:
Dominik Moritz Roth 2025-01-07 18:23:50 +01:00
parent 4d6ed9b3ac
commit 404320c5cc

View File

@ -160,26 +160,28 @@ class KLProjection(BaseProjection):
old_cov = old_scale_or_tril ** 2 old_cov = old_scale_or_tril ** 2
mask = cov_part > self.cov_bound mask = cov_part > self.cov_bound
proj_scale_or_tril = scale_or_tril # Start with original scale
# Always compute both branches and use matrix operations to select if mask.any():
if self.full_cov: if self.full_cov:
proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound) proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound)
is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1))) is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1)))
valid_mask = mask & ~is_invalid proj_scale_or_tril = jnp.where(is_invalid[..., None, None], old_scale_or_tril, scale_or_tril)
mask = mask & ~is_invalid
# Compute cholesky for all, let matrix ops handle selection chol = jnp.linalg.cholesky(proj_cov)
chol = jnp.linalg.cholesky(proj_cov) proj_scale_or_tril = jnp.where(mask[..., None, None], chol, proj_scale_or_tril)
mask_matrix = valid_mask[..., None, None].astype(scale_or_tril.dtype) else:
return mask_matrix * chol + (1 - mask_matrix) * scale_or_tril proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound)
is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) |
jnp.isinf(proj_cov.mean(axis=-1)) |
(proj_cov.min(axis=-1) < 0))
proj_scale_or_tril = jnp.where(is_invalid[..., None], old_scale_or_tril, scale_or_tril)
mask = mask & ~is_invalid
proj_scale_or_tril = jnp.where(mask[..., None], jnp.sqrt(proj_cov), scale_or_tril)
else: else:
proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound) proj_scale_or_tril = scale_or_tril
is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) |
jnp.isinf(proj_cov.mean(axis=-1)) |
(proj_cov.min(axis=-1) < 0))
valid_mask = mask & ~is_invalid
mask_matrix = valid_mask[..., None].astype(scale_or_tril.dtype) return proj_scale_or_tril
return mask_matrix * jnp.sqrt(proj_cov) + (1 - mask_matrix) * scale_or_tril
def _validate_inputs(self, policy_params, old_policy_params): def _validate_inputs(self, policy_params, old_policy_params):
"""Validate input parameters have correct format.""" """Validate input parameters have correct format."""