Compare commits
No commits in common. "2e0ca977bc143d5a729876ffad6c08ffab831ba5" and "de2b9a10d6f807a40821fbea42aca3d4eb2356d7" have entirely different histories.
2e0ca977bc
...
de2b9a10d6
@ -12,7 +12,7 @@ JAX bindings and native implementations of differentiable trust region projectio
|
||||
- Multiple projection types:
|
||||
- KL (Kullback-Leibler divergence)
|
||||
- Wasserstein (only diagonal covariance)
|
||||
- Frobenius (wip, problem with cov projections)
|
||||
- Frobenius (wip, not tested)
|
||||
- Identity (no projection)
|
||||
- Support for both diagonal and full covariance Gaussians (induced from cholesky decomposition)
|
||||
- Contextual and non-contextual standard deviations (non-contextual means all standard deviations in batch are expected to be the same)
|
||||
|
@ -88,18 +88,8 @@ class KLProjection(BaseProjection):
|
||||
|
||||
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
||||
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||
"""Compute trust region loss between original and projected parameters."""
|
||||
# Get the right scale parameter based on full_cov
|
||||
mean = policy_params["loc"]
|
||||
proj_mean = proj_policy_params["loc"]
|
||||
|
||||
if self.full_cov:
|
||||
scale_or_tril = policy_params["scale_tril"]
|
||||
proj_scale_or_tril = proj_policy_params["scale_tril"]
|
||||
else:
|
||||
scale_or_tril = policy_params["scale"]
|
||||
proj_scale_or_tril = proj_policy_params["scale"]
|
||||
|
||||
mean, scale_or_tril = policy_params["loc"], policy_params["scale"]
|
||||
proj_mean, proj_scale_or_tril = proj_policy_params["loc"], proj_policy_params["scale"]
|
||||
kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
|
||||
return jnp.mean(kl) * self.trust_region_coeff
|
||||
|
||||
@ -161,12 +151,14 @@ class KLProjection(BaseProjection):
|
||||
old_cov = old_scale_or_tril ** 2
|
||||
|
||||
mask = cov_part > self.cov_bound
|
||||
|
||||
proj_scale_or_tril = jnp.zeros_like(scale_or_tril)
|
||||
proj_scale_or_tril = jnp.where(~mask, scale_or_tril, proj_scale_or_tril)
|
||||
|
||||
if mask.any():
|
||||
if self.full_cov:
|
||||
proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound)
|
||||
is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1)))
|
||||
proj_scale_or_tril = jnp.where(is_invalid[..., None, None], old_scale_or_tril, scale_or_tril)
|
||||
is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1))) & mask
|
||||
proj_scale_or_tril = jnp.where(is_invalid, old_scale_or_tril, proj_scale_or_tril)
|
||||
mask = mask & ~is_invalid
|
||||
chol = jnp.linalg.cholesky(proj_cov)
|
||||
proj_scale_or_tril = jnp.where(mask[..., None, None], chol, proj_scale_or_tril)
|
||||
@ -174,12 +166,10 @@ class KLProjection(BaseProjection):
|
||||
proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound)
|
||||
is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) |
|
||||
jnp.isinf(proj_cov.mean(axis=-1)) |
|
||||
(proj_cov.min(axis=-1) < 0))
|
||||
proj_scale_or_tril = jnp.where(is_invalid[..., None], old_scale_or_tril, scale_or_tril)
|
||||
(proj_cov.min(axis=-1) < 0)) & mask
|
||||
proj_scale_or_tril = jnp.where(is_invalid, old_scale_or_tril, proj_scale_or_tril)
|
||||
mask = mask & ~is_invalid
|
||||
proj_scale_or_tril = jnp.where(mask[..., None], jnp.sqrt(proj_cov), scale_or_tril)
|
||||
else:
|
||||
proj_scale_or_tril = scale_or_tril
|
||||
proj_scale_or_tril = jnp.where(mask[..., None], jnp.sqrt(proj_cov), proj_scale_or_tril)
|
||||
|
||||
return proj_scale_or_tril
|
||||
|
||||
|
@ -151,10 +151,6 @@ def test_full_covariance_projection(ProjectionClass):
|
||||
eigvals = jnp.linalg.eigvalsh(cov)
|
||||
assert jnp.all(eigvals > 0)
|
||||
|
||||
# Check trust region loss computation works
|
||||
tr_loss = proj.get_trust_region_loss(params, proj_params)
|
||||
assert jnp.isfinite(tr_loss)
|
||||
|
||||
# Only check KL bounds for KL projection
|
||||
if ProjectionClass in [KLProjection]:
|
||||
kl = compute_gaussian_kl(proj_params, old_params, full_cov=True)
|
||||
|
Loading…
Reference in New Issue
Block a user