jit wherever possible

This commit is contained in:
Dominik Moritz Roth 2024-12-21 19:21:24 +01:00
parent 2e0ca977bc
commit 7fca6186d5
5 changed files with 61 additions and 31 deletions

View File

@ -1,6 +1,8 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict from typing import Dict
import jax
import jax.numpy as jnp import jax.numpy as jnp
from functools import partial
class BaseProjection(ABC): class BaseProjection(ABC):
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01, def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
@ -8,22 +10,54 @@ class BaseProjection(ABC):
self.trust_region_coeff = trust_region_coeff self.trust_region_coeff = trust_region_coeff
self.mean_bound = mean_bound self.mean_bound = mean_bound
self.cov_bound = cov_bound self.cov_bound = cov_bound
self.full_cov = full_cov
self.contextual_std = contextual_std self.contextual_std = contextual_std
self.full_cov = full_cov
@abstractmethod
def project(self, policy_params: Dict[str, jnp.ndarray],
            old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
    """Project parameters to satisfy trust region constraints.

    Abstract hook: this base version only raises, concrete projections
    override it.

    Args:
        policy_params: Dictionary with:
            - 'loc': mean parameters (batch_size, dim)
            - 'scale': standard deviations (batch_size, dim) if full_cov=False
            - 'scale_tril': Cholesky factor (batch_size, dim, dim) if full_cov=True
        old_policy_params: Same format as policy_params

    Returns:
        Dictionary of projected parameters (subclasses visible in this
        change return the same keys as the input).

    Raises:
        NotImplementedError: always, when the base implementation is called.
    """
    raise NotImplementedError
@abstractmethod
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
                          proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Compute trust region loss between original and projected parameters.

    Abstract hook: this base version only raises, concrete projections
    override it.

    Args:
        policy_params: Parameters before projection.
        proj_policy_params: Parameters after projection.

    Returns:
        The trust region loss as an array.

    Raises:
        NotImplementedError: always, when the base implementation is called.
    """
    raise NotImplementedError
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
mean_part: jnp.ndarray) -> jnp.ndarray:
"""Project mean based on the Mahalanobis objective and trust region.
Args: Args:
policy_params: Dictionary with: mean: Current mean vectors
- 'loc': mean parameters (batch_size, dim) old_mean: Old mean vectors
- 'scale': standard deviations (batch_size, dim) if full_cov=False mean_part: Mahalanobis/Euclidean distance between the two mean vectors
- 'scale_tril': Cholesky factor (batch_size, dim, dim) if full_cov=True
old_policy_params: Same format as policy_params Returns:
Projected mean that satisfies the trust region
""" """
pass mask = mean_part > self.mean_bound
# If nothing needs to be projected, skip computation
if not jnp.any(mask):
return mean
# Compute projection factor
omega = jnp.ones(mean_part.shape, dtype=mean.dtype)
omega = jnp.where(mask,
jnp.sqrt(mean_part / self.mean_bound) - 1.,
omega)
omega = jnp.maximum(-omega, omega)[..., None]
# Project mean
m = (mean + omega * old_mean) / (1. + omega + 1e-16)
return jnp.where(mask[..., None], m, mean)
def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray,
cov_part: jnp.ndarray) -> jnp.ndarray:
"""Project covariance parameters."""
raise NotImplementedError
def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray: def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""Convert scale representation to covariance matrix.""" """Convert scale representation to covariance matrix."""

View File

@ -2,6 +2,7 @@ import jax.numpy as jnp
from .base_projection import BaseProjection from .base_projection import BaseProjection
from typing import Dict from typing import Dict
import jax import jax
from functools import partial
class FrobeniusProjection(BaseProjection): class FrobeniusProjection(BaseProjection):
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01, def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
@ -47,6 +48,7 @@ class FrobeniusProjection(BaseProjection):
else: else:
return {"loc": proj_mean, "scale": scale_or_tril} return {"loc": proj_mean, "scale": scale_or_tril}
@partial(jax.jit, static_argnames=('self'))
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
mean = policy_params["loc"] mean = policy_params["loc"]
@ -60,6 +62,7 @@ class FrobeniusProjection(BaseProjection):
return (mean_diff + cov_diff).mean() * self.trust_region_coeff return (mean_diff + cov_diff).mean() * self.trust_region_coeff
@partial(jax.jit, static_argnames=('self'))
def _gaussian_frobenius(self, p, q): def _gaussian_frobenius(self, p, q):
mean, cov = p mean, cov = p
old_mean, old_cov = q old_mean, old_cov = q
@ -88,14 +91,6 @@ class FrobeniusProjection(BaseProjection):
return mean_part, cov_part return mean_part, cov_part
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
mean_part: jnp.ndarray) -> jnp.ndarray:
diff = mean - old_mean
norm = jnp.sqrt(mean_part)
return jnp.where(norm > self.mean_bound,
old_mean + diff * self.mean_bound / norm[..., None],
mean)
def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray, def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray,
cov_part: jnp.ndarray) -> jnp.ndarray: cov_part: jnp.ndarray) -> jnp.ndarray:
batch_shape = cov.shape[:-2] if cov.ndim > 2 else cov.shape[:-1] batch_shape = cov.shape[:-2] if cov.ndim > 2 else cov.shape[:-1]

View File

@ -1,6 +1,8 @@
import jax.numpy as jnp import jax.numpy as jnp
from .base_projection import BaseProjection from .base_projection import BaseProjection
from typing import Dict from typing import Dict
import jax
from functools import partial
class IdentityProjection(BaseProjection): class IdentityProjection(BaseProjection):
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01, def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
@ -8,10 +10,12 @@ class IdentityProjection(BaseProjection):
super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound,
cov_bound=cov_bound, contextual_std=contextual_std, full_cov=full_cov) cov_bound=cov_bound, contextual_std=contextual_std, full_cov=full_cov)
@partial(jax.jit, static_argnames=('self',))
def project(self, policy_params: Dict[str, jnp.ndarray],
            old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
    """Identity projection: hand the new parameters back untouched.

    Args:
        policy_params: Current policy parameters.
        old_policy_params: Previous policy parameters (ignored).

    Returns:
        ``policy_params`` as received.
    """
    unchanged = policy_params
    return unchanged
@partial(jax.jit, static_argnames=('self',))
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
                          proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Return a constant zero loss; both arguments are ignored.

    Returns:
        Scalar array holding ``0.0``.
    """
    zero_loss = jnp.array(0.0)
    return zero_loss

View File

@ -86,6 +86,7 @@ class KLProjection(BaseProjection):
else: else:
return {"loc": proj_mean, "scale": proj_scale_or_tril} return {"loc": proj_mean, "scale": proj_scale_or_tril}
@partial(jax.jit, static_argnames=('self'))
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""Compute trust region loss between original and projected parameters.""" """Compute trust region loss between original and projected parameters."""
@ -103,6 +104,7 @@ class KLProjection(BaseProjection):
kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril))) kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
return jnp.mean(kl) * self.trust_region_coeff return jnp.mean(kl) * self.trust_region_coeff
@partial(jax.jit, static_argnames=('self'))
def _gaussian_kl(self, p: Tuple[jnp.ndarray, jnp.ndarray], def _gaussian_kl(self, p: Tuple[jnp.ndarray, jnp.ndarray],
q: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]: q: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
mean, scale_or_tril = p mean, scale_or_tril = p
@ -127,6 +129,7 @@ class KLProjection(BaseProjection):
return maha_part, cov_part return maha_part, cov_part
@partial(jax.jit, static_argnames=('self'))
def _maha(self, x: jnp.ndarray, y: jnp.ndarray, scale_or_tril: jnp.ndarray) -> jnp.ndarray: def _maha(self, x: jnp.ndarray, y: jnp.ndarray, scale_or_tril: jnp.ndarray) -> jnp.ndarray:
diff = x - y diff = x - y
if self.full_cov: if self.full_cov:
@ -137,21 +140,17 @@ class KLProjection(BaseProjection):
else: else:
return jnp.sum(jnp.square(diff / scale_or_tril), axis=-1) return jnp.sum(jnp.square(diff / scale_or_tril), axis=-1)
@partial(jax.jit, static_argnames=('self'))
def _log_determinant(self, scale_or_tril: jnp.ndarray) -> jnp.ndarray: def _log_determinant(self, scale_or_tril: jnp.ndarray) -> jnp.ndarray:
if self.full_cov: if self.full_cov:
return 2 * jnp.sum(jnp.log(jnp.diagonal(scale_or_tril, axis1=-2, axis2=-1)), axis=-1) return 2 * jnp.sum(jnp.log(jnp.diagonal(scale_or_tril, axis1=-2, axis2=-1)), axis=-1)
else: else:
return 2 * jnp.sum(jnp.log(scale_or_tril), axis=-1) return 2 * jnp.sum(jnp.log(scale_or_tril), axis=-1)
@partial(jax.jit, static_argnames=('self'))
def _batched_trace_square(self, x: jnp.ndarray) -> jnp.ndarray: def _batched_trace_square(self, x: jnp.ndarray) -> jnp.ndarray:
return jnp.sum(x ** 2, axis=(-2, -1)) return jnp.sum(x ** 2, axis=(-2, -1))
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
mean_part: jnp.ndarray) -> jnp.ndarray:
return old_mean + (mean - old_mean) * jnp.sqrt(
self.mean_bound / (mean_part + 1e-8)
)[..., None]
def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray, cov_part: jnp.ndarray) -> jnp.ndarray: def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray, cov_part: jnp.ndarray) -> jnp.ndarray:
if self.full_cov: if self.full_cov:
cov = jnp.matmul(scale_or_tril, jnp.swapaxes(scale_or_tril, -1, -2)) cov = jnp.matmul(scale_or_tril, jnp.swapaxes(scale_or_tril, -1, -2))
@ -161,6 +160,7 @@ class KLProjection(BaseProjection):
old_cov = old_scale_or_tril ** 2 old_cov = old_scale_or_tril ** 2
mask = cov_part > self.cov_bound mask = cov_part > self.cov_bound
proj_scale_or_tril = scale_or_tril # Start with original scale
if mask.any(): if mask.any():
if self.full_cov: if self.full_cov:

View File

@ -2,7 +2,9 @@ import jax.numpy as jnp
from .base_projection import BaseProjection from .base_projection import BaseProjection
from typing import Dict, Tuple from typing import Dict, Tuple
import jax import jax
from functools import partial
@jax.jit
def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray: def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
""" """
'Converts' scale_tril to scale_sqrt. 'Converts' scale_tril to scale_sqrt.
@ -52,6 +54,7 @@ class WassersteinProjection(BaseProjection):
return {"loc": proj_mean, "scale": proj_scale} return {"loc": proj_mean, "scale": proj_scale}
@partial(jax.jit, static_argnames=('self'))
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
mean = policy_params["loc"] mean = policy_params["loc"]
@ -65,14 +68,6 @@ class WassersteinProjection(BaseProjection):
w2 = mean_part + cov_part w2 = mean_part + cov_part
return w2.mean() * self.trust_region_coeff return w2.mean() * self.trust_region_coeff
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
mean_part: jnp.ndarray) -> jnp.ndarray:
diff = mean - old_mean
norm = jnp.sqrt(mean_part)
return jnp.where(norm > self.mean_bound,
old_mean + diff * self.mean_bound / norm[..., None],
mean)
def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray, def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray,
scale_part: jnp.ndarray) -> jnp.ndarray: scale_part: jnp.ndarray) -> jnp.ndarray:
"""Project scale parameters using multiplicative update. """Project scale parameters using multiplicative update.
@ -108,7 +103,9 @@ class WassersteinProjection(BaseProjection):
return jnp.where(mask, new_scale, scale) return jnp.where(mask, new_scale, scale)
def _gaussian_wasserstein(self, p, q): @staticmethod
@jax.jit
def _gaussian_wasserstein(p, q):
mean, scale = p mean, scale = p
mean_other, scale_other = q mean_other, scale_other = q