jit wherever possible
This commit is contained in:
parent
2e0ca977bc
commit
7fca6186d5
@ -1,6 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from functools import partial
|
||||
|
||||
class BaseProjection(ABC):
|
||||
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
|
||||
@ -8,22 +10,54 @@ class BaseProjection(ABC):
|
||||
self.trust_region_coeff = trust_region_coeff
|
||||
self.mean_bound = mean_bound
|
||||
self.cov_bound = cov_bound
|
||||
self.full_cov = full_cov
|
||||
self.contextual_std = contextual_std
|
||||
self.full_cov = full_cov
|
||||
|
||||
@abstractmethod
|
||||
def project(self, policy_params: Dict[str, jnp.ndarray],
|
||||
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
|
||||
"""Project policy parameters.
|
||||
"""Project parameters to satisfy trust region constraints."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
||||
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||
"""Compute trust region loss between original and projected parameters."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
|
||||
mean_part: jnp.ndarray) -> jnp.ndarray:
|
||||
"""Project mean based on the Mahalanobis objective and trust region.
|
||||
|
||||
Args:
|
||||
policy_params: Dictionary with:
|
||||
- 'loc': mean parameters (batch_size, dim)
|
||||
- 'scale': standard deviations (batch_size, dim) if full_cov=False
|
||||
- 'scale_tril': Cholesky factor (batch_size, dim, dim) if full_cov=True
|
||||
old_policy_params: Same format as policy_params
|
||||
mean: Current mean vectors
|
||||
old_mean: Old mean vectors
|
||||
mean_part: Mahalanobis/Euclidean distance between the two mean vectors
|
||||
|
||||
Returns:
|
||||
Projected mean that satisfies the trust region
|
||||
"""
|
||||
pass
|
||||
mask = mean_part > self.mean_bound
|
||||
|
||||
# If nothing needs to be projected, skip computation
|
||||
if not jnp.any(mask):
|
||||
return mean
|
||||
|
||||
# Compute projection factor
|
||||
omega = jnp.ones(mean_part.shape, dtype=mean.dtype)
|
||||
omega = jnp.where(mask,
|
||||
jnp.sqrt(mean_part / self.mean_bound) - 1.,
|
||||
omega)
|
||||
omega = jnp.maximum(-omega, omega)[..., None]
|
||||
|
||||
# Project mean
|
||||
m = (mean + omega * old_mean) / (1. + omega + 1e-16)
|
||||
return jnp.where(mask[..., None], m, mean)
|
||||
|
||||
def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray,
|
||||
cov_part: jnp.ndarray) -> jnp.ndarray:
|
||||
"""Project covariance parameters."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||
"""Convert scale representation to covariance matrix."""
|
||||
|
@ -2,6 +2,7 @@ import jax.numpy as jnp
|
||||
from .base_projection import BaseProjection
|
||||
from typing import Dict
|
||||
import jax
|
||||
from functools import partial
|
||||
|
||||
class FrobeniusProjection(BaseProjection):
|
||||
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
|
||||
@ -47,6 +48,7 @@ class FrobeniusProjection(BaseProjection):
|
||||
else:
|
||||
return {"loc": proj_mean, "scale": scale_or_tril}
|
||||
|
||||
@partial(jax.jit, static_argnames=('self'))
|
||||
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
||||
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||
mean = policy_params["loc"]
|
||||
@ -60,6 +62,7 @@ class FrobeniusProjection(BaseProjection):
|
||||
|
||||
return (mean_diff + cov_diff).mean() * self.trust_region_coeff
|
||||
|
||||
@partial(jax.jit, static_argnames=('self'))
|
||||
def _gaussian_frobenius(self, p, q):
|
||||
mean, cov = p
|
||||
old_mean, old_cov = q
|
||||
@ -88,14 +91,6 @@ class FrobeniusProjection(BaseProjection):
|
||||
|
||||
return mean_part, cov_part
|
||||
|
||||
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
|
||||
mean_part: jnp.ndarray) -> jnp.ndarray:
|
||||
diff = mean - old_mean
|
||||
norm = jnp.sqrt(mean_part)
|
||||
return jnp.where(norm > self.mean_bound,
|
||||
old_mean + diff * self.mean_bound / norm[..., None],
|
||||
mean)
|
||||
|
||||
def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray,
|
||||
cov_part: jnp.ndarray) -> jnp.ndarray:
|
||||
batch_shape = cov.shape[:-2] if cov.ndim > 2 else cov.shape[:-1]
|
||||
|
@ -1,6 +1,8 @@
|
||||
import jax.numpy as jnp
|
||||
from .base_projection import BaseProjection
|
||||
from typing import Dict
|
||||
import jax
|
||||
from functools import partial
|
||||
|
||||
class IdentityProjection(BaseProjection):
|
||||
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
|
||||
@ -8,10 +10,12 @@ class IdentityProjection(BaseProjection):
|
||||
super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound,
|
||||
cov_bound=cov_bound, contextual_std=contextual_std, full_cov=full_cov)
|
||||
|
||||
@partial(jax.jit, static_argnames=('self'))
|
||||
def project(self, policy_params: Dict[str, jnp.ndarray],
|
||||
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
|
||||
return policy_params
|
||||
|
||||
@partial(jax.jit, static_argnames=('self'))
|
||||
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
||||
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||
return jnp.array(0.0)
|
@ -86,6 +86,7 @@ class KLProjection(BaseProjection):
|
||||
else:
|
||||
return {"loc": proj_mean, "scale": proj_scale_or_tril}
|
||||
|
||||
@partial(jax.jit, static_argnames=('self'))
|
||||
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
||||
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||
"""Compute trust region loss between original and projected parameters."""
|
||||
@ -103,6 +104,7 @@ class KLProjection(BaseProjection):
|
||||
kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
|
||||
return jnp.mean(kl) * self.trust_region_coeff
|
||||
|
||||
@partial(jax.jit, static_argnames=('self'))
|
||||
def _gaussian_kl(self, p: Tuple[jnp.ndarray, jnp.ndarray],
|
||||
q: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||
mean, scale_or_tril = p
|
||||
@ -127,6 +129,7 @@ class KLProjection(BaseProjection):
|
||||
|
||||
return maha_part, cov_part
|
||||
|
||||
@partial(jax.jit, static_argnames=('self'))
|
||||
def _maha(self, x: jnp.ndarray, y: jnp.ndarray, scale_or_tril: jnp.ndarray) -> jnp.ndarray:
|
||||
diff = x - y
|
||||
if self.full_cov:
|
||||
@ -137,21 +140,17 @@ class KLProjection(BaseProjection):
|
||||
else:
|
||||
return jnp.sum(jnp.square(diff / scale_or_tril), axis=-1)
|
||||
|
||||
@partial(jax.jit, static_argnames=('self'))
|
||||
def _log_determinant(self, scale_or_tril: jnp.ndarray) -> jnp.ndarray:
|
||||
if self.full_cov:
|
||||
return 2 * jnp.sum(jnp.log(jnp.diagonal(scale_or_tril, axis1=-2, axis2=-1)), axis=-1)
|
||||
else:
|
||||
return 2 * jnp.sum(jnp.log(scale_or_tril), axis=-1)
|
||||
|
||||
@partial(jax.jit, static_argnames=('self'))
|
||||
def _batched_trace_square(self, x: jnp.ndarray) -> jnp.ndarray:
|
||||
return jnp.sum(x ** 2, axis=(-2, -1))
|
||||
|
||||
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
|
||||
mean_part: jnp.ndarray) -> jnp.ndarray:
|
||||
return old_mean + (mean - old_mean) * jnp.sqrt(
|
||||
self.mean_bound / (mean_part + 1e-8)
|
||||
)[..., None]
|
||||
|
||||
def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray, cov_part: jnp.ndarray) -> jnp.ndarray:
|
||||
if self.full_cov:
|
||||
cov = jnp.matmul(scale_or_tril, jnp.swapaxes(scale_or_tril, -1, -2))
|
||||
@ -161,6 +160,7 @@ class KLProjection(BaseProjection):
|
||||
old_cov = old_scale_or_tril ** 2
|
||||
|
||||
mask = cov_part > self.cov_bound
|
||||
proj_scale_or_tril = scale_or_tril # Start with original scale
|
||||
|
||||
if mask.any():
|
||||
if self.full_cov:
|
||||
|
@ -2,7 +2,9 @@ import jax.numpy as jnp
|
||||
from .base_projection import BaseProjection
|
||||
from typing import Dict, Tuple
|
||||
import jax
|
||||
from functools import partial
|
||||
|
||||
@jax.jit
|
||||
def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
|
||||
"""
|
||||
'Converts' scale_tril to scale_sqrt.
|
||||
@ -52,6 +54,7 @@ class WassersteinProjection(BaseProjection):
|
||||
|
||||
return {"loc": proj_mean, "scale": proj_scale}
|
||||
|
||||
@partial(jax.jit, static_argnames=('self'))
|
||||
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
||||
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||
mean = policy_params["loc"]
|
||||
@ -65,14 +68,6 @@ class WassersteinProjection(BaseProjection):
|
||||
w2 = mean_part + cov_part
|
||||
return w2.mean() * self.trust_region_coeff
|
||||
|
||||
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
|
||||
mean_part: jnp.ndarray) -> jnp.ndarray:
|
||||
diff = mean - old_mean
|
||||
norm = jnp.sqrt(mean_part)
|
||||
return jnp.where(norm > self.mean_bound,
|
||||
old_mean + diff * self.mean_bound / norm[..., None],
|
||||
mean)
|
||||
|
||||
def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray,
|
||||
scale_part: jnp.ndarray) -> jnp.ndarray:
|
||||
"""Project scale parameters using multiplicative update.
|
||||
@ -108,7 +103,9 @@ class WassersteinProjection(BaseProjection):
|
||||
|
||||
return jnp.where(mask, new_scale, scale)
|
||||
|
||||
def _gaussian_wasserstein(self, p, q):
|
||||
@staticmethod
|
||||
@jax.jit
|
||||
def _gaussian_wasserstein(p, q):
|
||||
mean, scale = p
|
||||
mean_other, scale_other = q
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user