jit wherever possible

This commit is contained in:
Dominik Moritz Roth 2024-12-21 19:21:24 +01:00
parent 2e0ca977bc
commit 7fca6186d5
5 changed files with 61 additions and 31 deletions

View File

@ -1,6 +1,8 @@
from abc import ABC, abstractmethod
from typing import Dict
import jax
import jax.numpy as jnp
from functools import partial
class BaseProjection(ABC):
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
@ -8,22 +10,54 @@ class BaseProjection(ABC):
self.trust_region_coeff = trust_region_coeff
self.mean_bound = mean_bound
self.cov_bound = cov_bound
self.full_cov = full_cov
self.contextual_std = contextual_std
self.full_cov = full_cov
@abstractmethod
def project(self, policy_params: Dict[str, jnp.ndarray],
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
"""Project policy parameters.
"""Project parameters to satisfy trust region constraints."""
raise NotImplementedError
@abstractmethod
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""Compute trust region loss between original and projected parameters."""
raise NotImplementedError
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
mean_part: jnp.ndarray) -> jnp.ndarray:
"""Project mean based on the Mahalanobis objective and trust region.
Args:
policy_params: Dictionary with:
- 'loc': mean parameters (batch_size, dim)
- 'scale': standard deviations (batch_size, dim) if full_cov=False
- 'scale_tril': Cholesky factor (batch_size, dim, dim) if full_cov=True
old_policy_params: Same format as policy_params
mean: Current mean vectors
old_mean: Old mean vectors
mean_part: Mahalanobis/Euclidean distance between the two mean vectors
Returns:
Projected mean that satisfies the trust region
"""
pass
mask = mean_part > self.mean_bound
# If nothing needs to be projected, skip computation
if not jnp.any(mask):
return mean
# Compute projection factor
omega = jnp.ones(mean_part.shape, dtype=mean.dtype)
omega = jnp.where(mask,
jnp.sqrt(mean_part / self.mean_bound) - 1.,
omega)
omega = jnp.maximum(-omega, omega)[..., None]
# Project mean
m = (mean + omega * old_mean) / (1. + omega + 1e-16)
return jnp.where(mask[..., None], m, mean)
def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray,
cov_part: jnp.ndarray) -> jnp.ndarray:
"""Project covariance parameters."""
raise NotImplementedError
def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""Convert scale representation to covariance matrix."""

View File

@ -2,6 +2,7 @@ import jax.numpy as jnp
from .base_projection import BaseProjection
from typing import Dict
import jax
from functools import partial
class FrobeniusProjection(BaseProjection):
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
@ -47,6 +48,7 @@ class FrobeniusProjection(BaseProjection):
else:
return {"loc": proj_mean, "scale": scale_or_tril}
@partial(jax.jit, static_argnames=('self'))
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
mean = policy_params["loc"]
@ -60,6 +62,7 @@ class FrobeniusProjection(BaseProjection):
return (mean_diff + cov_diff).mean() * self.trust_region_coeff
@partial(jax.jit, static_argnames=('self'))
def _gaussian_frobenius(self, p, q):
mean, cov = p
old_mean, old_cov = q
@ -88,14 +91,6 @@ class FrobeniusProjection(BaseProjection):
return mean_part, cov_part
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
mean_part: jnp.ndarray) -> jnp.ndarray:
diff = mean - old_mean
norm = jnp.sqrt(mean_part)
return jnp.where(norm > self.mean_bound,
old_mean + diff * self.mean_bound / norm[..., None],
mean)
def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray,
cov_part: jnp.ndarray) -> jnp.ndarray:
batch_shape = cov.shape[:-2] if cov.ndim > 2 else cov.shape[:-1]

View File

@ -1,6 +1,8 @@
import jax.numpy as jnp
from .base_projection import BaseProjection
from typing import Dict
import jax
from functools import partial
class IdentityProjection(BaseProjection):
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
@ -8,10 +10,12 @@ class IdentityProjection(BaseProjection):
super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound,
cov_bound=cov_bound, contextual_std=contextual_std, full_cov=full_cov)
@partial(jax.jit, static_argnames=('self'))
def project(self, policy_params: Dict[str, jnp.ndarray],
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
return policy_params
@partial(jax.jit, static_argnames=('self'))
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
return jnp.array(0.0)

View File

@ -86,6 +86,7 @@ class KLProjection(BaseProjection):
else:
return {"loc": proj_mean, "scale": proj_scale_or_tril}
@partial(jax.jit, static_argnames=('self'))
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""Compute trust region loss between original and projected parameters."""
@ -103,6 +104,7 @@ class KLProjection(BaseProjection):
kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
return jnp.mean(kl) * self.trust_region_coeff
@partial(jax.jit, static_argnames=('self'))
def _gaussian_kl(self, p: Tuple[jnp.ndarray, jnp.ndarray],
q: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
mean, scale_or_tril = p
@ -127,6 +129,7 @@ class KLProjection(BaseProjection):
return maha_part, cov_part
@partial(jax.jit, static_argnames=('self'))
def _maha(self, x: jnp.ndarray, y: jnp.ndarray, scale_or_tril: jnp.ndarray) -> jnp.ndarray:
diff = x - y
if self.full_cov:
@ -137,21 +140,17 @@ class KLProjection(BaseProjection):
else:
return jnp.sum(jnp.square(diff / scale_or_tril), axis=-1)
@partial(jax.jit, static_argnames=('self'))
def _log_determinant(self, scale_or_tril: jnp.ndarray) -> jnp.ndarray:
if self.full_cov:
return 2 * jnp.sum(jnp.log(jnp.diagonal(scale_or_tril, axis1=-2, axis2=-1)), axis=-1)
else:
return 2 * jnp.sum(jnp.log(scale_or_tril), axis=-1)
@partial(jax.jit, static_argnames=('self'))
def _batched_trace_square(self, x: jnp.ndarray) -> jnp.ndarray:
return jnp.sum(x ** 2, axis=(-2, -1))
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
mean_part: jnp.ndarray) -> jnp.ndarray:
return old_mean + (mean - old_mean) * jnp.sqrt(
self.mean_bound / (mean_part + 1e-8)
)[..., None]
def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray, cov_part: jnp.ndarray) -> jnp.ndarray:
if self.full_cov:
cov = jnp.matmul(scale_or_tril, jnp.swapaxes(scale_or_tril, -1, -2))
@ -161,6 +160,7 @@ class KLProjection(BaseProjection):
old_cov = old_scale_or_tril ** 2
mask = cov_part > self.cov_bound
proj_scale_or_tril = scale_or_tril # Start with original scale
if mask.any():
if self.full_cov:

View File

@ -2,7 +2,9 @@ import jax.numpy as jnp
from .base_projection import BaseProjection
from typing import Dict, Tuple
import jax
from functools import partial
@jax.jit
def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
"""
'Converts' scale_tril to scale_sqrt.
@ -52,6 +54,7 @@ class WassersteinProjection(BaseProjection):
return {"loc": proj_mean, "scale": proj_scale}
@partial(jax.jit, static_argnames=('self'))
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
mean = policy_params["loc"]
@ -65,14 +68,6 @@ class WassersteinProjection(BaseProjection):
w2 = mean_part + cov_part
return w2.mean() * self.trust_region_coeff
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
mean_part: jnp.ndarray) -> jnp.ndarray:
diff = mean - old_mean
norm = jnp.sqrt(mean_part)
return jnp.where(norm > self.mean_bound,
old_mean + diff * self.mean_bound / norm[..., None],
mean)
def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray,
scale_part: jnp.ndarray) -> jnp.ndarray:
"""Project scale parameters using multiplicative update.
@ -108,7 +103,9 @@ class WassersteinProjection(BaseProjection):
return jnp.where(mask, new_scale, scale)
def _gaussian_wasserstein(self, p, q):
@staticmethod
@jax.jit
def _gaussian_wasserstein(p, q):
mean, scale = p
mean_other, scale_other = q