Fixes for contextual KL
This commit is contained in:
parent
de2b9a10d6
commit
3e2b988a2f
@ -88,8 +88,18 @@ class KLProjection(BaseProjection):
|
|||||||
|
|
||||||
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
||||||
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||||
mean, scale_or_tril = policy_params["loc"], policy_params["scale"]
|
"""Compute trust region loss between original and projected parameters."""
|
||||||
proj_mean, proj_scale_or_tril = proj_policy_params["loc"], proj_policy_params["scale"]
|
# Get the right scale parameter based on full_cov
|
||||||
|
mean = policy_params["loc"]
|
||||||
|
proj_mean = proj_policy_params["loc"]
|
||||||
|
|
||||||
|
if self.full_cov:
|
||||||
|
scale_or_tril = policy_params["scale_tril"]
|
||||||
|
proj_scale_or_tril = proj_policy_params["scale_tril"]
|
||||||
|
else:
|
||||||
|
scale_or_tril = policy_params["scale"]
|
||||||
|
proj_scale_or_tril = proj_policy_params["scale"]
|
||||||
|
|
||||||
kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
|
kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
|
||||||
return jnp.mean(kl) * self.trust_region_coeff
|
return jnp.mean(kl) * self.trust_region_coeff
|
||||||
|
|
||||||
@ -151,14 +161,12 @@ class KLProjection(BaseProjection):
|
|||||||
old_cov = old_scale_or_tril ** 2
|
old_cov = old_scale_or_tril ** 2
|
||||||
|
|
||||||
mask = cov_part > self.cov_bound
|
mask = cov_part > self.cov_bound
|
||||||
proj_scale_or_tril = jnp.zeros_like(scale_or_tril)
|
|
||||||
proj_scale_or_tril = jnp.where(~mask, scale_or_tril, proj_scale_or_tril)
|
|
||||||
|
|
||||||
if mask.any():
|
if mask.any():
|
||||||
if self.full_cov:
|
if self.full_cov:
|
||||||
proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound)
|
proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound)
|
||||||
is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1))) & mask
|
is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1)))
|
||||||
proj_scale_or_tril = jnp.where(is_invalid, old_scale_or_tril, proj_scale_or_tril)
|
proj_scale_or_tril = jnp.where(is_invalid[..., None, None], old_scale_or_tril, scale_or_tril)
|
||||||
mask = mask & ~is_invalid
|
mask = mask & ~is_invalid
|
||||||
chol = jnp.linalg.cholesky(proj_cov)
|
chol = jnp.linalg.cholesky(proj_cov)
|
||||||
proj_scale_or_tril = jnp.where(mask[..., None, None], chol, proj_scale_or_tril)
|
proj_scale_or_tril = jnp.where(mask[..., None, None], chol, proj_scale_or_tril)
|
||||||
@ -166,10 +174,12 @@ class KLProjection(BaseProjection):
|
|||||||
proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound)
|
proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound)
|
||||||
is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) |
|
is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) |
|
||||||
jnp.isinf(proj_cov.mean(axis=-1)) |
|
jnp.isinf(proj_cov.mean(axis=-1)) |
|
||||||
(proj_cov.min(axis=-1) < 0)) & mask
|
(proj_cov.min(axis=-1) < 0))
|
||||||
proj_scale_or_tril = jnp.where(is_invalid, old_scale_or_tril, proj_scale_or_tril)
|
proj_scale_or_tril = jnp.where(is_invalid[..., None], old_scale_or_tril, scale_or_tril)
|
||||||
mask = mask & ~is_invalid
|
mask = mask & ~is_invalid
|
||||||
proj_scale_or_tril = jnp.where(mask[..., None], jnp.sqrt(proj_cov), proj_scale_or_tril)
|
proj_scale_or_tril = jnp.where(mask[..., None], jnp.sqrt(proj_cov), scale_or_tril)
|
||||||
|
else:
|
||||||
|
proj_scale_or_tril = scale_or_tril
|
||||||
|
|
||||||
return proj_scale_or_tril
|
return proj_scale_or_tril
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user