jit wherever possible
This commit is contained in:
parent
2e0ca977bc
commit
7fca6186d5
@ -1,6 +1,8 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
class BaseProjection(ABC):
|
class BaseProjection(ABC):
|
||||||
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
|
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
|
||||||
@ -8,22 +10,54 @@ class BaseProjection(ABC):
|
|||||||
self.trust_region_coeff = trust_region_coeff
|
self.trust_region_coeff = trust_region_coeff
|
||||||
self.mean_bound = mean_bound
|
self.mean_bound = mean_bound
|
||||||
self.cov_bound = cov_bound
|
self.cov_bound = cov_bound
|
||||||
self.full_cov = full_cov
|
|
||||||
self.contextual_std = contextual_std
|
self.contextual_std = contextual_std
|
||||||
|
self.full_cov = full_cov
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def project(self, policy_params: Dict[str, jnp.ndarray],
|
def project(self, policy_params: Dict[str, jnp.ndarray],
|
||||||
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
|
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
|
||||||
"""Project policy parameters.
|
"""Project parameters to satisfy trust region constraints."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
||||||
|
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||||
|
"""Compute trust region loss between original and projected parameters."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
|
||||||
|
mean_part: jnp.ndarray) -> jnp.ndarray:
|
||||||
|
"""Project mean based on the Mahalanobis objective and trust region.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
policy_params: Dictionary with:
|
mean: Current mean vectors
|
||||||
- 'loc': mean parameters (batch_size, dim)
|
old_mean: Old mean vectors
|
||||||
- 'scale': standard deviations (batch_size, dim) if full_cov=False
|
mean_part: Mahalanobis/Euclidean distance between the two mean vectors
|
||||||
- 'scale_tril': Cholesky factor (batch_size, dim, dim) if full_cov=True
|
|
||||||
old_policy_params: Same format as policy_params
|
Returns:
|
||||||
|
Projected mean that satisfies the trust region
|
||||||
"""
|
"""
|
||||||
pass
|
mask = mean_part > self.mean_bound
|
||||||
|
|
||||||
|
# If nothing needs to be projected, skip computation
|
||||||
|
if not jnp.any(mask):
|
||||||
|
return mean
|
||||||
|
|
||||||
|
# Compute projection factor
|
||||||
|
omega = jnp.ones(mean_part.shape, dtype=mean.dtype)
|
||||||
|
omega = jnp.where(mask,
|
||||||
|
jnp.sqrt(mean_part / self.mean_bound) - 1.,
|
||||||
|
omega)
|
||||||
|
omega = jnp.maximum(-omega, omega)[..., None]
|
||||||
|
|
||||||
|
# Project mean
|
||||||
|
m = (mean + omega * old_mean) / (1. + omega + 1e-16)
|
||||||
|
return jnp.where(mask[..., None], m, mean)
|
||||||
|
|
||||||
|
def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray,
|
||||||
|
cov_part: jnp.ndarray) -> jnp.ndarray:
|
||||||
|
"""Project covariance parameters."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||||
"""Convert scale representation to covariance matrix."""
|
"""Convert scale representation to covariance matrix."""
|
||||||
|
@ -2,6 +2,7 @@ import jax.numpy as jnp
|
|||||||
from .base_projection import BaseProjection
|
from .base_projection import BaseProjection
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
import jax
|
import jax
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
class FrobeniusProjection(BaseProjection):
|
class FrobeniusProjection(BaseProjection):
|
||||||
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
|
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
|
||||||
@ -47,6 +48,7 @@ class FrobeniusProjection(BaseProjection):
|
|||||||
else:
|
else:
|
||||||
return {"loc": proj_mean, "scale": scale_or_tril}
|
return {"loc": proj_mean, "scale": scale_or_tril}
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnames=('self'))
|
||||||
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
||||||
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||||
mean = policy_params["loc"]
|
mean = policy_params["loc"]
|
||||||
@ -60,6 +62,7 @@ class FrobeniusProjection(BaseProjection):
|
|||||||
|
|
||||||
return (mean_diff + cov_diff).mean() * self.trust_region_coeff
|
return (mean_diff + cov_diff).mean() * self.trust_region_coeff
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnames=('self'))
|
||||||
def _gaussian_frobenius(self, p, q):
|
def _gaussian_frobenius(self, p, q):
|
||||||
mean, cov = p
|
mean, cov = p
|
||||||
old_mean, old_cov = q
|
old_mean, old_cov = q
|
||||||
@ -88,14 +91,6 @@ class FrobeniusProjection(BaseProjection):
|
|||||||
|
|
||||||
return mean_part, cov_part
|
return mean_part, cov_part
|
||||||
|
|
||||||
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
|
|
||||||
mean_part: jnp.ndarray) -> jnp.ndarray:
|
|
||||||
diff = mean - old_mean
|
|
||||||
norm = jnp.sqrt(mean_part)
|
|
||||||
return jnp.where(norm > self.mean_bound,
|
|
||||||
old_mean + diff * self.mean_bound / norm[..., None],
|
|
||||||
mean)
|
|
||||||
|
|
||||||
def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray,
|
def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray,
|
||||||
cov_part: jnp.ndarray) -> jnp.ndarray:
|
cov_part: jnp.ndarray) -> jnp.ndarray:
|
||||||
batch_shape = cov.shape[:-2] if cov.ndim > 2 else cov.shape[:-1]
|
batch_shape = cov.shape[:-2] if cov.ndim > 2 else cov.shape[:-1]
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from .base_projection import BaseProjection
|
from .base_projection import BaseProjection
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
import jax
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
class IdentityProjection(BaseProjection):
|
class IdentityProjection(BaseProjection):
|
||||||
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
|
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
|
||||||
@ -8,10 +10,12 @@ class IdentityProjection(BaseProjection):
|
|||||||
super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound,
|
super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound,
|
||||||
cov_bound=cov_bound, contextual_std=contextual_std, full_cov=full_cov)
|
cov_bound=cov_bound, contextual_std=contextual_std, full_cov=full_cov)
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnames=('self'))
|
||||||
def project(self, policy_params: Dict[str, jnp.ndarray],
|
def project(self, policy_params: Dict[str, jnp.ndarray],
|
||||||
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
|
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
|
||||||
return policy_params
|
return policy_params
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnames=('self'))
|
||||||
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
||||||
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||||
return jnp.array(0.0)
|
return jnp.array(0.0)
|
@ -86,6 +86,7 @@ class KLProjection(BaseProjection):
|
|||||||
else:
|
else:
|
||||||
return {"loc": proj_mean, "scale": proj_scale_or_tril}
|
return {"loc": proj_mean, "scale": proj_scale_or_tril}
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnames=('self'))
|
||||||
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
||||||
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||||
"""Compute trust region loss between original and projected parameters."""
|
"""Compute trust region loss between original and projected parameters."""
|
||||||
@ -103,6 +104,7 @@ class KLProjection(BaseProjection):
|
|||||||
kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
|
kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
|
||||||
return jnp.mean(kl) * self.trust_region_coeff
|
return jnp.mean(kl) * self.trust_region_coeff
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnames=('self'))
|
||||||
def _gaussian_kl(self, p: Tuple[jnp.ndarray, jnp.ndarray],
|
def _gaussian_kl(self, p: Tuple[jnp.ndarray, jnp.ndarray],
|
||||||
q: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
q: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||||
mean, scale_or_tril = p
|
mean, scale_or_tril = p
|
||||||
@ -127,6 +129,7 @@ class KLProjection(BaseProjection):
|
|||||||
|
|
||||||
return maha_part, cov_part
|
return maha_part, cov_part
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnames=('self'))
|
||||||
def _maha(self, x: jnp.ndarray, y: jnp.ndarray, scale_or_tril: jnp.ndarray) -> jnp.ndarray:
|
def _maha(self, x: jnp.ndarray, y: jnp.ndarray, scale_or_tril: jnp.ndarray) -> jnp.ndarray:
|
||||||
diff = x - y
|
diff = x - y
|
||||||
if self.full_cov:
|
if self.full_cov:
|
||||||
@ -137,21 +140,17 @@ class KLProjection(BaseProjection):
|
|||||||
else:
|
else:
|
||||||
return jnp.sum(jnp.square(diff / scale_or_tril), axis=-1)
|
return jnp.sum(jnp.square(diff / scale_or_tril), axis=-1)
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnames=('self'))
|
||||||
def _log_determinant(self, scale_or_tril: jnp.ndarray) -> jnp.ndarray:
|
def _log_determinant(self, scale_or_tril: jnp.ndarray) -> jnp.ndarray:
|
||||||
if self.full_cov:
|
if self.full_cov:
|
||||||
return 2 * jnp.sum(jnp.log(jnp.diagonal(scale_or_tril, axis1=-2, axis2=-1)), axis=-1)
|
return 2 * jnp.sum(jnp.log(jnp.diagonal(scale_or_tril, axis1=-2, axis2=-1)), axis=-1)
|
||||||
else:
|
else:
|
||||||
return 2 * jnp.sum(jnp.log(scale_or_tril), axis=-1)
|
return 2 * jnp.sum(jnp.log(scale_or_tril), axis=-1)
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnames=('self'))
|
||||||
def _batched_trace_square(self, x: jnp.ndarray) -> jnp.ndarray:
|
def _batched_trace_square(self, x: jnp.ndarray) -> jnp.ndarray:
|
||||||
return jnp.sum(x ** 2, axis=(-2, -1))
|
return jnp.sum(x ** 2, axis=(-2, -1))
|
||||||
|
|
||||||
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
|
|
||||||
mean_part: jnp.ndarray) -> jnp.ndarray:
|
|
||||||
return old_mean + (mean - old_mean) * jnp.sqrt(
|
|
||||||
self.mean_bound / (mean_part + 1e-8)
|
|
||||||
)[..., None]
|
|
||||||
|
|
||||||
def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray, cov_part: jnp.ndarray) -> jnp.ndarray:
|
def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray, cov_part: jnp.ndarray) -> jnp.ndarray:
|
||||||
if self.full_cov:
|
if self.full_cov:
|
||||||
cov = jnp.matmul(scale_or_tril, jnp.swapaxes(scale_or_tril, -1, -2))
|
cov = jnp.matmul(scale_or_tril, jnp.swapaxes(scale_or_tril, -1, -2))
|
||||||
@ -161,6 +160,7 @@ class KLProjection(BaseProjection):
|
|||||||
old_cov = old_scale_or_tril ** 2
|
old_cov = old_scale_or_tril ** 2
|
||||||
|
|
||||||
mask = cov_part > self.cov_bound
|
mask = cov_part > self.cov_bound
|
||||||
|
proj_scale_or_tril = scale_or_tril # Start with original scale
|
||||||
|
|
||||||
if mask.any():
|
if mask.any():
|
||||||
if self.full_cov:
|
if self.full_cov:
|
||||||
|
@ -2,7 +2,9 @@ import jax.numpy as jnp
|
|||||||
from .base_projection import BaseProjection
|
from .base_projection import BaseProjection
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
import jax
|
import jax
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
|
def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
|
||||||
"""
|
"""
|
||||||
'Converts' scale_tril to scale_sqrt.
|
'Converts' scale_tril to scale_sqrt.
|
||||||
@ -52,6 +54,7 @@ class WassersteinProjection(BaseProjection):
|
|||||||
|
|
||||||
return {"loc": proj_mean, "scale": proj_scale}
|
return {"loc": proj_mean, "scale": proj_scale}
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnames=('self'))
|
||||||
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
||||||
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||||
mean = policy_params["loc"]
|
mean = policy_params["loc"]
|
||||||
@ -65,14 +68,6 @@ class WassersteinProjection(BaseProjection):
|
|||||||
w2 = mean_part + cov_part
|
w2 = mean_part + cov_part
|
||||||
return w2.mean() * self.trust_region_coeff
|
return w2.mean() * self.trust_region_coeff
|
||||||
|
|
||||||
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
|
|
||||||
mean_part: jnp.ndarray) -> jnp.ndarray:
|
|
||||||
diff = mean - old_mean
|
|
||||||
norm = jnp.sqrt(mean_part)
|
|
||||||
return jnp.where(norm > self.mean_bound,
|
|
||||||
old_mean + diff * self.mean_bound / norm[..., None],
|
|
||||||
mean)
|
|
||||||
|
|
||||||
def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray,
|
def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray,
|
||||||
scale_part: jnp.ndarray) -> jnp.ndarray:
|
scale_part: jnp.ndarray) -> jnp.ndarray:
|
||||||
"""Project scale parameters using multiplicative update.
|
"""Project scale parameters using multiplicative update.
|
||||||
@ -108,7 +103,9 @@ class WassersteinProjection(BaseProjection):
|
|||||||
|
|
||||||
return jnp.where(mask, new_scale, scale)
|
return jnp.where(mask, new_scale, scale)
|
||||||
|
|
||||||
def _gaussian_wasserstein(self, p, q):
|
@staticmethod
|
||||||
|
@jax.jit
|
||||||
|
def _gaussian_wasserstein(p, q):
|
||||||
mean, scale = p
|
mean, scale = p
|
||||||
mean_other, scale_other = q
|
mean_other, scale_other = q
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user