From 7fca6186d5c541438ac33b76c8b8c669ee535721 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 21 Dec 2024 19:21:24 +0100 Subject: [PATCH] jit wherever possible --- itpal_jax/base_projection.py | 50 ++++++++++++++++++++++++----- itpal_jax/frobenius_projection.py | 11 ++----- itpal_jax/identity_projection.py | 4 +++ itpal_jax/kl_projection.py | 12 +++---- itpal_jax/wasserstein_projection.py | 15 ++++----- 5 files changed, 61 insertions(+), 31 deletions(-) diff --git a/itpal_jax/base_projection.py b/itpal_jax/base_projection.py index 7c28ac5..a72ae16 100644 --- a/itpal_jax/base_projection.py +++ b/itpal_jax/base_projection.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod from typing import Dict +import jax import jax.numpy as jnp +from functools import partial class BaseProjection(ABC): def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01, @@ -8,22 +10,54 @@ class BaseProjection(ABC): self.trust_region_coeff = trust_region_coeff self.mean_bound = mean_bound self.cov_bound = cov_bound - self.full_cov = full_cov self.contextual_std = contextual_std + self.full_cov = full_cov @abstractmethod def project(self, policy_params: Dict[str, jnp.ndarray], old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]: - """Project policy parameters. + """Project parameters to satisfy trust region constraints.""" + raise NotImplementedError + + @abstractmethod + def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], + proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: + """Compute trust region loss between original and projected parameters.""" + raise NotImplementedError + + def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray, + mean_part: jnp.ndarray) -> jnp.ndarray: + """Project mean based on the Mahalanobis objective and trust region. Args: - policy_params: Dictionary with: - - 'loc': mean parameters (batch_size, dim) - - 'scale': standard deviations (batch_size, dim) if full_cov=False - - 'scale_tril': Cholesky factor (batch_size, dim, dim) if full_cov=True - old_policy_params: Same format as policy_params + mean: Current mean vectors + old_mean: Old mean vectors + mean_part: Mahalanobis/Euclidean distance between the two mean vectors + + Returns: + Projected mean that satisfies the trust region """ - pass + mask = mean_part > self.mean_bound + + # If nothing needs to be projected, skip computation + if not jnp.any(mask): + return mean + + # Compute projection factor + omega = jnp.ones(mean_part.shape, dtype=mean.dtype) + omega = jnp.where(mask, + jnp.sqrt(mean_part / self.mean_bound) - 1., + omega) + omega = jnp.maximum(-omega, omega)[..., None] + + # Project mean + m = (mean + omega * old_mean) / (1. + omega + 1e-16) + return jnp.where(mask[..., None], m, mean) + + def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray, + cov_part: jnp.ndarray) -> jnp.ndarray: + """Project covariance parameters.""" + raise NotImplementedError def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray: """Convert scale representation to covariance matrix.""" diff --git a/itpal_jax/frobenius_projection.py b/itpal_jax/frobenius_projection.py index 2faadfb..e18820d 100644 --- a/itpal_jax/frobenius_projection.py +++ b/itpal_jax/frobenius_projection.py @@ -2,6 +2,7 @@ import jax.numpy as jnp from .base_projection import BaseProjection from typing import Dict import jax +from functools import partial class FrobeniusProjection(BaseProjection): def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01, @@ -47,6 +48,7 @@ class FrobeniusProjection(BaseProjection): else: return {"loc": proj_mean, "scale": scale_or_tril} + @partial(jax.jit, static_argnames=('self')) def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: mean = policy_params["loc"] @@ -60,6 +62,7 @@ class FrobeniusProjection(BaseProjection): return (mean_diff + cov_diff).mean() * self.trust_region_coeff + @partial(jax.jit, static_argnames=('self')) def _gaussian_frobenius(self, p, q): mean, cov = p old_mean, old_cov = q @@ -88,14 +91,6 @@ class FrobeniusProjection(BaseProjection): return mean_part, cov_part - def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray, - mean_part: jnp.ndarray) -> jnp.ndarray: - diff = mean - old_mean - norm = jnp.sqrt(mean_part) - return jnp.where(norm > self.mean_bound, - old_mean + diff * self.mean_bound / norm[..., None], - mean) - def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray, cov_part: jnp.ndarray) -> jnp.ndarray: batch_shape = cov.shape[:-2] if cov.ndim > 2 else cov.shape[:-1] diff --git a/itpal_jax/identity_projection.py b/itpal_jax/identity_projection.py index a27d202..467f473 100644 --- a/itpal_jax/identity_projection.py +++ b/itpal_jax/identity_projection.py @@ -1,6 +1,8 @@ import jax.numpy as jnp from .base_projection import BaseProjection from typing import Dict +import jax +from functools import partial class IdentityProjection(BaseProjection): def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01, @@ -8,10 +10,12 @@ class IdentityProjection(BaseProjection): super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std, full_cov=full_cov) + @partial(jax.jit, static_argnames=('self')) def project(self, policy_params: Dict[str, jnp.ndarray], old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]: return policy_params + @partial(jax.jit, static_argnames=('self')) def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: return jnp.array(0.0) \ No newline at end of file diff --git a/itpal_jax/kl_projection.py b/itpal_jax/kl_projection.py index f017d98..511f42b 100644 --- a/itpal_jax/kl_projection.py +++ b/itpal_jax/kl_projection.py @@ -86,6 +86,7 @@ class KLProjection(BaseProjection): else: return {"loc": proj_mean, "scale": proj_scale_or_tril} + @partial(jax.jit, static_argnames=('self')) def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: """Compute trust region loss between original and projected parameters.""" @@ -103,6 +104,7 @@ class KLProjection(BaseProjection): kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril))) return jnp.mean(kl) * self.trust_region_coeff + @partial(jax.jit, static_argnames=('self')) def _gaussian_kl(self, p: Tuple[jnp.ndarray, jnp.ndarray], q: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]: mean, scale_or_tril = p @@ -127,6 +129,7 @@ class KLProjection(BaseProjection): return maha_part, cov_part + @partial(jax.jit, static_argnames=('self')) def _maha(self, x: jnp.ndarray, y: jnp.ndarray, scale_or_tril: jnp.ndarray) -> jnp.ndarray: diff = x - y if self.full_cov: @@ -137,21 +140,17 @@ class KLProjection(BaseProjection): else: return jnp.sum(jnp.square(diff / scale_or_tril), axis=-1) + @partial(jax.jit, static_argnames=('self')) def _log_determinant(self, scale_or_tril: jnp.ndarray) -> jnp.ndarray: if self.full_cov: return 2 * jnp.sum(jnp.log(jnp.diagonal(scale_or_tril, axis1=-2, axis2=-1)), axis=-1) else: return 2 * jnp.sum(jnp.log(scale_or_tril), axis=-1) + @partial(jax.jit, static_argnames=('self')) def _batched_trace_square(self, x: jnp.ndarray) -> jnp.ndarray: return jnp.sum(x ** 2, axis=(-2, -1)) - def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray, - mean_part: jnp.ndarray) -> jnp.ndarray: - return old_mean + (mean - old_mean) * jnp.sqrt( - self.mean_bound / (mean_part + 1e-8) - )[..., None] - def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray, cov_part: jnp.ndarray) -> jnp.ndarray: if self.full_cov: cov = jnp.matmul(scale_or_tril, jnp.swapaxes(scale_or_tril, -1, -2)) @@ -161,6 +160,7 @@ class KLProjection(BaseProjection): old_cov = old_scale_or_tril ** 2 mask = cov_part > self.cov_bound + proj_scale_or_tril = scale_or_tril # Start with original scale if mask.any(): if self.full_cov: diff --git a/itpal_jax/wasserstein_projection.py b/itpal_jax/wasserstein_projection.py index ddf33d4..56a1c09 100644 --- a/itpal_jax/wasserstein_projection.py +++ b/itpal_jax/wasserstein_projection.py @@ -2,7 +2,9 @@ import jax.numpy as jnp from .base_projection import BaseProjection from typing import Dict, Tuple import jax +from functools import partial +@jax.jit def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray: """ 'Converts' scale_tril to scale_sqrt. @@ -52,6 +54,7 @@ class WassersteinProjection(BaseProjection): return {"loc": proj_mean, "scale": proj_scale} + @partial(jax.jit, static_argnames=('self')) def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: mean = policy_params["loc"] @@ -65,14 +68,6 @@ class WassersteinProjection(BaseProjection): w2 = mean_part + cov_part return w2.mean() * self.trust_region_coeff - def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray, - mean_part: jnp.ndarray) -> jnp.ndarray: - diff = mean - old_mean - norm = jnp.sqrt(mean_part) - return jnp.where(norm > self.mean_bound, - old_mean + diff * self.mean_bound / norm[..., None], - mean) - def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray, scale_part: jnp.ndarray) -> jnp.ndarray: """Project scale parameters using multiplicative update. @@ -108,7 +103,9 @@ class WassersteinProjection(BaseProjection): return jnp.where(mask, new_scale, scale) - def _gaussian_wasserstein(self, p, q): + @staticmethod + @jax.jit + def _gaussian_wasserstein(p, q): mean, scale = p mean_other, scale_other = q