diff --git a/itpal_jax/kl_projection.py b/itpal_jax/kl_projection.py
index 4a66403..f017d98 100644
--- a/itpal_jax/kl_projection.py
+++ b/itpal_jax/kl_projection.py
@@ -88,8 +88,18 @@ class KLProjection(BaseProjection):
 
     def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
                               proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
-        mean, scale_or_tril = policy_params["loc"], policy_params["scale"]
-        proj_mean, proj_scale_or_tril = proj_policy_params["loc"], proj_policy_params["scale"]
+        """Compute trust region loss between original and projected parameters."""
+        # Full-covariance parameters live under "scale_tril", diagonal ones under "scale"
+        mean = policy_params["loc"]
+        proj_mean = proj_policy_params["loc"]
+
+        if self.full_cov:
+            scale_or_tril = policy_params["scale_tril"]
+            proj_scale_or_tril = proj_policy_params["scale_tril"]
+        else:
+            scale_or_tril = policy_params["scale"]
+            proj_scale_or_tril = proj_policy_params["scale"]
+
         kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
         return jnp.mean(kl) * self.trust_region_coeff
 
@@ -151,14 +161,13 @@ class KLProjection(BaseProjection):
 
         old_cov = old_scale_or_tril ** 2
         mask = cov_part > self.cov_bound
-        proj_scale_or_tril = jnp.zeros_like(scale_or_tril)
-        proj_scale_or_tril = jnp.where(~mask, scale_or_tril, proj_scale_or_tril)
-        
+
         if mask.any():
             if self.full_cov:
                 proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound)
-                is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1))) & mask
-                proj_scale_or_tril = jnp.where(is_invalid, old_scale_or_tril, proj_scale_or_tril)
+                is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1)))
+                # Start from the unprojected factor; rows whose projection is NaN fall back to the old one
+                proj_scale_or_tril = jnp.where(is_invalid[..., None, None], old_scale_or_tril, scale_or_tril)
                 mask = mask & ~is_invalid
                 chol = jnp.linalg.cholesky(proj_cov)
                 proj_scale_or_tril = jnp.where(mask[..., None, None], chol, proj_scale_or_tril)
@@ -166,10 +175,13 @@ class KLProjection(BaseProjection):
                 proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound)
                 is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) |
                               jnp.isinf(proj_cov.mean(axis=-1)) |
-                              (proj_cov.min(axis=-1) < 0)) & mask
-                proj_scale_or_tril = jnp.where(is_invalid, old_scale_or_tril, proj_scale_or_tril)
+                              (proj_cov.min(axis=-1) < 0))
+                proj_scale_or_tril = jnp.where(is_invalid[..., None], old_scale_or_tril, scale_or_tril)
                 mask = mask & ~is_invalid
+                # Chain on proj_scale_or_tril (not scale_or_tril) so the invalid-row fallback above is kept
                 proj_scale_or_tril = jnp.where(mask[..., None], jnp.sqrt(proj_cov), proj_scale_or_tril)
+        else:
+            proj_scale_or_tril = scale_or_tril
 
         return proj_scale_or_tril
 