to git we go

Dominik Moritz Roth 2024-12-11 18:33:40 +01:00
commit 9f85217a47
8 changed files with 526 additions and 0 deletions

0
README.md Normal file
View File

24
itpal_jax/__init__.py Normal file
View File

@@ -0,0 +1,24 @@
"""
JAX implementations of various policy projection methods.
Available projections:
- BaseProjection: Abstract base class for projections
- IdentityProjection: Simple identity mapping
- FrobeniusProjection: Frobenius norm-based projection
- WassersteinProjection: Wasserstein distance-based projection
- KLProjection: KL divergence-based projection (requires C++ backend)
"""
from .base_projection import BaseProjection
from .identity_projection import IdentityProjection
from .frobenius_projection import FrobeniusProjection
from .wasserstein_projection import WassersteinProjection
from .kl_projection import KLProjection
__all__ = [
'BaseProjection',
'IdentityProjection',
'FrobeniusProjection',
'WassersteinProjection',
'KLProjection',
]
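For orientation, a minimal usage sketch against the API introduced in this commit; the array shapes and parameter values here are made up for illustration, and FrobeniusProjection stands in for any of the exported projections:

import jax.numpy as jnp
from itpal_jax import FrobeniusProjection

# Hypothetical diagonal-Gaussian policy parameters: "loc" and "scale" are (batch, dim)
new_params = {"loc": jnp.zeros((4, 2)), "scale": 0.5 * jnp.ones((4, 2))}
old_params = {"loc": 0.1 * jnp.ones((4, 2)), "scale": 0.4 * jnp.ones((4, 2))}

proj = FrobeniusProjection(trust_region_coeff=1.0, mean_bound=0.01, cov_bound=0.01)
proj_params = proj.project(new_params, old_params)              # projected {"loc", "scale"}
tr_loss = proj.get_trust_region_loss(new_params, proj_params)   # scalar penalty added to the policy loss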

35
itpal_jax/base_projection.py Normal file
View File

@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod
from typing import Dict
import jax.numpy as jnp
class BaseProjection(ABC):
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
cov_bound: float = 0.01, contextual_std: bool = True, full_cov: bool = False):
self.trust_region_coeff = trust_region_coeff
self.mean_bound = mean_bound
self.cov_bound = cov_bound
self.full_cov = full_cov
self.contextual_std = contextual_std
@abstractmethod
def project(self, policy_params: Dict[str, jnp.ndarray],
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
pass
@abstractmethod
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
pass
def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
if not self.full_cov:
            # Build (possibly batched) diagonal covariance matrices from the std-dev vector
            var = params["scale"] ** 2
            return jnp.eye(var.shape[-1], dtype=var.dtype) * var[..., None, :]
else:
scale_tril = params["scale_tril"]
return jnp.matmul(scale_tril, jnp.swapaxes(scale_tril, -1, -2))
def _calc_scale_or_scale_tril(self, cov: jnp.ndarray) -> jnp.ndarray:
if not self.full_cov:
return jnp.sqrt(jnp.diagonal(cov, axis1=-2, axis2=-1))
else:
return jnp.linalg.cholesky(cov)
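These two helpers encode the parameter convention used throughout the package: with full_cov=False a policy is described by {"loc", "scale"} (per-dimension standard deviations), with full_cov=True by {"loc", "scale_tril"} (a Cholesky factor of the covariance). A small round-trip sketch of that convention (values are illustrative only):

import jax.numpy as jnp

scale = jnp.array([0.5, 2.0])                                  # diagonal case: per-dimension std devs
cov = jnp.diag(scale ** 2)                                     # the covariance _calc_covariance builds
scale_back = jnp.sqrt(jnp.diagonal(cov, axis1=-2, axis2=-1))   # what _calc_scale_or_scale_tril recovers

scale_tril = jnp.linalg.cholesky(jnp.array([[1.0, 0.3],
                                            [0.3, 2.0]]))      # full-covariance case: Cholesky factor
cov_full = scale_tril @ scale_tril.T                           # matches the full_cov branch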

20
itpal_jax/exception_projection.py Normal file
View File

@@ -0,0 +1,20 @@
from .base_projection import BaseProjection
from functools import partial
import jax.numpy as jnp
from typing import Dict
class ExceptionProjection(BaseProjection):
def __init__(self, msg, *args, **kwargs):
self.msg = msg
raise Exception(msg)
def project(self, policy_params: Dict[str, jnp.ndarray],
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
pass
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
pass
def makeExceptionProjection(msg: str):
return partial(ExceptionProjection, msg)

78
itpal_jax/frobenius_projection.py Normal file
View File

@@ -0,0 +1,78 @@
import jax.numpy as jnp
from .base_projection import BaseProjection
from typing import Dict
class FrobeniusProjection(BaseProjection):
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
cov_bound: float = 0.01, scale_prec: bool = False,
contextual_std: bool = True, full_cov: bool = False):
super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound,
cov_bound=cov_bound, contextual_std=contextual_std, full_cov=full_cov)
self.scale_prec = scale_prec
def project(self, policy_params: Dict[str, jnp.ndarray],
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
mean = policy_params["loc"]
old_mean = old_policy_params["loc"]
cov = self._calc_covariance(policy_params)
old_cov = self._calc_covariance(old_policy_params)
mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_cov = self._cov_projection(cov, old_cov, cov_part)
scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov)
return {"loc": proj_mean, "scale": scale_or_scale_tril}
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
mean = policy_params["loc"]
proj_mean = proj_policy_params["loc"]
cov = self._calc_covariance(policy_params)
proj_cov = self._calc_covariance(proj_policy_params)
mean_diff = jnp.sum(jnp.square(mean - proj_mean), axis=-1)
cov_diff = jnp.sum(jnp.square(cov - proj_cov), axis=(-2, -1))
return (mean_diff + cov_diff).mean() * self.trust_region_coeff
def _gaussian_frobenius(self, p, q):
mean, cov = p
old_mean, old_cov = q
if self.scale_prec:
prec_old = jnp.linalg.inv(old_cov)
mean_part = jnp.sum(jnp.matmul(mean - old_mean, prec_old) * (mean - old_mean), axis=-1)
cov_part = jnp.sum(prec_old * cov, axis=(-2, -1)) - jnp.log(jnp.linalg.det(jnp.matmul(prec_old, cov))) - mean.shape[-1]
else:
mean_part = jnp.sum(jnp.square(mean - old_mean), axis=-1)
cov_part = jnp.sum(jnp.square(cov - old_cov), axis=(-2, -1))
return mean_part, cov_part
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
mean_part: jnp.ndarray) -> jnp.ndarray:
        diff = mean - old_mean
        # Expand the per-sample norm so the condition broadcasts against (batch, dim) means
        norm = jnp.sqrt(mean_part)[..., None]
        return jnp.where(norm > self.mean_bound,
                         old_mean + diff * self.mean_bound / norm,
                         mean)
def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray,
cov_part: jnp.ndarray) -> jnp.ndarray:
batch_shape = cov.shape[:-2]
cov_mask = cov_part > self.cov_bound
eta = jnp.ones(batch_shape, dtype=cov.dtype)
eta = jnp.where(cov_mask,
jnp.sqrt(cov_part / self.cov_bound) - 1.,
eta)
eta = jnp.maximum(-eta, eta)
new_cov = (cov + jnp.einsum('i,ijk->ijk', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov)
return proj_cov
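When the Frobenius term exceeds cov_bound, the covariance step interpolates back toward the old covariance with eta = sqrt(cov_part / cov_bound) - 1, i.e. proj_cov = (cov + eta * old_cov) / (1 + eta). A toy scalar check of that interpolation (numbers made up); note that the projected value lands exactly on the bound:

import jax.numpy as jnp

cov, old_cov, cov_bound = jnp.array(2.0), jnp.array(1.0), 0.01
cov_part = jnp.square(cov - old_cov)            # Frobenius term for a 1x1 "covariance": 1.0
eta = jnp.sqrt(cov_part / cov_bound) - 1.0      # 9.0
proj_cov = (cov + eta * old_cov) / (1.0 + eta)  # 1.1, and (proj_cov - old_cov)**2 == cov_bound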

17
itpal_jax/identity_projection.py Normal file
View File

@@ -0,0 +1,17 @@
import jax.numpy as jnp
from .base_projection import BaseProjection
from typing import Dict
class IdentityProjection(BaseProjection):
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
cov_bound: float = 0.01, contextual_std: bool = True, full_cov: bool = False):
super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound,
cov_bound=cov_bound, contextual_std=contextual_std, full_cov=full_cov)
def project(self, policy_params: Dict[str, jnp.ndarray],
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
return policy_params
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
return jnp.array(0.0)

244
itpal_jax/kl_projection.py Normal file
View File

@@ -0,0 +1,244 @@
try:
import cpp_projection
cpp_projection_available = True
except ImportError:
cpp_projection_available = False
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import Dict, Tuple, Any
from .base_projection import BaseProjection
from .exception_projection import makeExceptionProjection
MAX_EVAL = 1000
class KLProjection(BaseProjection):
"""KL divergence-based projection for Gaussian policies.
    This class implements the KL-divergence projection in JAX, using the
    cpp_projection C++ backend for the covariance projection step. It supports
    both diagonal and full covariance matrices.
    Args:
        trust_region_coeff (float): Coefficient for the trust region loss
        mean_bound (float): Bound for the mean projection
        cov_bound (float): Bound for the covariance projection
        contextual_std (bool): Whether to use contextual standard deviations
        full_cov (bool): Whether to use full covariance matrices
    """
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
cov_bound: float = 0.01, contextual_std: bool = True, full_cov: bool = False):
super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound,
cov_bound=cov_bound, contextual_std=contextual_std, full_cov=full_cov)
def project(self, policy_params: Dict[str, jnp.ndarray],
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
self._validate_inputs(policy_params, old_policy_params)
mean, scale_or_tril = policy_params["loc"], policy_params["scale"]
old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params["scale"]
mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril),
(old_mean, old_scale_or_tril))
if not self.contextual_std:
scale_or_tril = scale_or_tril[:1]
old_scale_or_tril = old_scale_or_tril[:1]
cov_part = cov_part[:1]
proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_scale_or_tril = self._cov_projection(scale_or_tril, old_scale_or_tril, cov_part)
if not self.contextual_std:
proj_scale_or_tril = jnp.broadcast_to(
proj_scale_or_tril,
(mean.shape[0],) + proj_scale_or_tril.shape[1:]
)
return {"loc": proj_mean, "scale": proj_scale_or_tril}
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
mean, scale_or_tril = policy_params["loc"], policy_params["scale"]
proj_mean, proj_scale_or_tril = proj_policy_params["loc"], proj_policy_params["scale"]
kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
return jnp.mean(kl) * self.trust_region_coeff
def _gaussian_kl(self, p: Tuple[jnp.ndarray, jnp.ndarray],
q: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
mean, scale_or_tril = p
mean_other, scale_or_tril_other = q
k = mean.shape[-1]
maha_part = 0.5 * self._maha(mean, mean_other, scale_or_tril_other)
det_term = self._log_determinant(scale_or_tril)
det_term_other = self._log_determinant(scale_or_tril_other)
if self.full_cov:
trace_part = self._batched_trace_square(
jax.scipy.linalg.solve_triangular(
scale_or_tril_other, scale_or_tril, lower=True
)
)
else:
trace_part = jnp.sum((scale_or_tril / scale_or_tril_other) ** 2, axis=-1)
cov_part = 0.5 * (trace_part - k + det_term_other - det_term)
return maha_part, cov_part
def _maha(self, x: jnp.ndarray, y: jnp.ndarray, scale_or_tril: jnp.ndarray) -> jnp.ndarray:
diff = x - y
if self.full_cov:
solved = jax.scipy.linalg.solve_triangular(
scale_or_tril, diff[..., None], lower=True
)
return jnp.sum(jnp.square(solved.squeeze(-1)), axis=-1)
else:
return jnp.sum(jnp.square(diff / scale_or_tril), axis=-1)
def _log_determinant(self, scale_or_tril: jnp.ndarray) -> jnp.ndarray:
if self.full_cov:
return 2 * jnp.sum(jnp.log(jnp.diagonal(scale_or_tril, axis1=-2, axis2=-1)), axis=-1)
else:
return 2 * jnp.sum(jnp.log(scale_or_tril), axis=-1)
def _batched_trace_square(self, x: jnp.ndarray) -> jnp.ndarray:
return jnp.sum(x ** 2, axis=(-2, -1))
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
mean_part: jnp.ndarray) -> jnp.ndarray:
        # Rescale the mean step only where the KL mean term exceeds the bound
        factor = jnp.sqrt(self.mean_bound / (mean_part + 1e-8))[..., None]
        return jnp.where((mean_part > self.mean_bound)[..., None],
                         old_mean + (mean - old_mean) * factor,
                         mean)
def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray, cov_part: jnp.ndarray) -> jnp.ndarray:
if self.full_cov:
cov = jnp.matmul(scale_or_tril, jnp.swapaxes(scale_or_tril, -1, -2))
old_cov = jnp.matmul(old_scale_or_tril, jnp.swapaxes(old_scale_or_tril, -1, -2))
else:
cov = scale_or_tril ** 2
old_cov = old_scale_or_tril ** 2
mask = cov_part > self.cov_bound
        proj_scale_or_tril = jnp.zeros_like(scale_or_tril)
        # Expand the per-sample mask so it broadcasts against vector- or matrix-valued scales
        expanded_mask = mask.reshape(mask.shape + (1,) * (scale_or_tril.ndim - mask.ndim))
        proj_scale_or_tril = jnp.where(~expanded_mask, scale_or_tril, proj_scale_or_tril)
if mask.any():
if self.full_cov:
proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound)
is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1))) & mask
                proj_scale_or_tril = jnp.where(is_invalid[..., None, None], old_scale_or_tril, proj_scale_or_tril)
mask = mask & ~is_invalid
chol = jnp.linalg.cholesky(proj_cov)
proj_scale_or_tril = jnp.where(mask[..., None, None], chol, proj_scale_or_tril)
else:
proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound)
is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) |
jnp.isinf(proj_cov.mean(axis=-1)) |
(proj_cov.min(axis=-1) < 0)) & mask
                proj_scale_or_tril = jnp.where(is_invalid[..., None], old_scale_or_tril, proj_scale_or_tril)
mask = mask & ~is_invalid
proj_scale_or_tril = jnp.where(mask[..., None], jnp.sqrt(proj_cov), proj_scale_or_tril)
return proj_scale_or_tril
def _validate_inputs(self, policy_params, old_policy_params):
required_keys = ["loc", "scale"]
for key in required_keys:
if key not in policy_params or key not in old_policy_params:
raise KeyError(f"Missing required key '{key}' in policy parameters")
@partial(jax.custom_vjp, nondiff_argnums=(3,))
def project_full_covariance(cov, chol, old_chol, eps_cov):
"""JAX wrapper for C++ full covariance projection"""
try:
# Convert JAX arrays to numpy for C++ function
cov_np = np.asarray(cov)
chol_np = np.asarray(chol)
old_chol_np = np.asarray(old_chol)
batch_shape = cov_np.shape[0]
dim = cov_np.shape[-1]
eps = eps_cov * np.ones(batch_shape)
# Create C++ projection operator directly
p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
# Run C++ projection
proj_cov = p_op.forward(eps, old_chol_np, chol_np, cov_np)
# Convert back to JAX array
return jnp.array(proj_cov)
except Exception as e:
print(f"Full covariance projection failed: {e}")
        # Fall back to the old covariance on failure (old_chol is its Cholesky factor)
        old_chol_j = jnp.asarray(old_chol)
        return jnp.matmul(old_chol_j, jnp.swapaxes(old_chol_j, -1, -2))
def project_full_covariance_fwd(cov, chol, old_chol, eps_cov):
y = project_full_covariance(cov, chol, old_chol, eps_cov)
return y, (cov, chol, old_chol)
def project_full_covariance_bwd(eps_cov, res, g):
cov, chol, old_chol = res
# Convert to numpy for C++ backward pass
g_np = np.asarray(g)
batch_shape = g_np.shape[0]
dim = g_np.shape[-1]
# Get C++ projection operator
p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
# Run C++ backward pass
grad_cov = p_op.backward(g_np)
    # Convert back to a JAX array; gradients flow only through the projected covariance,
    # so the factor inputs receive explicit zero cotangents
    return jnp.array(grad_cov), jnp.zeros_like(chol), jnp.zeros_like(old_chol)
# Register VJP rule for full covariance projection
project_full_covariance.defvjp(project_full_covariance_fwd, project_full_covariance_bwd)
@partial(jax.custom_vjp, nondiff_argnums=(2,))
def project_diag_covariance(cov, old_cov, eps_cov):
"""JAX wrapper for C++ diagonal covariance projection"""
# Convert JAX arrays to numpy for C++ function
cov_np = np.asarray(cov)
old_cov_np = np.asarray(old_cov)
batch_shape = cov_np.shape[0]
dim = cov_np.shape[-1]
eps = eps_cov * np.ones(batch_shape)
# Create C++ projection operator directly
p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
# Run C++ projection
proj_cov = p_op.forward(eps, old_cov_np, cov_np)
# Convert back to JAX array
return jnp.array(proj_cov)
def project_diag_covariance_fwd(cov, old_cov, eps_cov):
y = project_diag_covariance(cov, old_cov, eps_cov)
return y, (cov, old_cov)
def project_diag_covariance_bwd(eps_cov, res, g):
cov, old_cov = res
# Convert to numpy for C++ backward pass
g_np = np.asarray(g)
batch_shape = g_np.shape[0]
dim = g_np.shape[-1]
# Get C++ projection operator
p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
# Run C++ backward pass
grad_cov = p_op.backward(g_np)
    # Convert back to a JAX array; only the new covariance receives a gradient,
    # the old covariance gets an explicit zero cotangent
    return jnp.array(grad_cov), jnp.zeros_like(old_cov)
# Register VJP rule for diagonal covariance projection
project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd)
if not cpp_projection_available:
KLProjection = makeExceptionProjection("ITPAL (C++ library) is not available. Please install the C++ library to use this projection.")
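Because the covariance projection itself is delegated to the cpp_projection backend, the module rebinds KLProjection to an exception-raising stub when that import fails, so the error surfaces at construction time rather than at import time. A minimal sketch of how a caller might guard against this; the fallback to FrobeniusProjection is just one possible choice:

from itpal_jax import KLProjection, FrobeniusProjection

try:
    proj = KLProjection(mean_bound=0.05, cov_bound=0.001)
except Exception as err:  # raised by the ExceptionProjection stub when cpp_projection is missing
    print(f"KL projection unavailable ({err}); falling back to Frobenius")
    proj = FrobeniusProjection(mean_bound=0.05, cov_bound=0.001)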

108
itpal_jax/wasserstein_projection.py Normal file
View File

@@ -0,0 +1,108 @@
import jax.numpy as jnp
from .base_projection import BaseProjection
from typing import Dict, Tuple
def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
    """
    'Converts' scale_tril to scale_sqrt.
    The Wasserstein distance is defined via the matrix square root of the covariance,
    not its Cholesky factor. Both factorize the covariance, so the Cholesky factor is
    used here as a stand-in for the square root: exact in the diagonal case, an
    approximation for full covariances.
    """
    return scale_tril
def gaussian_wasserstein_commutative(p: Tuple[jnp.ndarray, jnp.ndarray],
q: Tuple[jnp.ndarray, jnp.ndarray],
scale_prec: bool = False) -> Tuple[jnp.ndarray, jnp.ndarray]:
mean, scale_or_sqrt = p
mean_other, scale_or_sqrt_other = q
mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)
if scale_or_sqrt.ndim == mean.ndim: # Diagonal case
cov = scale_or_sqrt ** 2
cov_other = scale_or_sqrt_other ** 2
        if scale_prec:
            # Diagonal form of tr(I + prec_other * cov - 2 * sqrt(prec_other) * sqrt(cov));
            # the identity contributes 1 per dimension
            sqrt_inv_other = 1 / scale_or_sqrt_other
            c = sqrt_inv_other ** 2 * cov
            cov_part = jnp.sum(1.0 + c - 2 * sqrt_inv_other * scale_or_sqrt, axis=-1)
else:
cov_part = jnp.sum(cov_other + cov - 2 * scale_or_sqrt_other * scale_or_sqrt, axis=-1)
else: # Full covariance case
# Note: scale_or_sqrt is treated as the matrix square root, not Cholesky decomposition
cov = jnp.matmul(scale_or_sqrt, jnp.swapaxes(scale_or_sqrt, -1, -2))
cov_other = jnp.matmul(scale_or_sqrt_other, jnp.swapaxes(scale_or_sqrt_other, -1, -2))
if scale_prec:
identity = jnp.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype)
sqrt_inv_other = jnp.linalg.solve(scale_or_sqrt_other, identity)
c = sqrt_inv_other @ cov @ jnp.swapaxes(sqrt_inv_other, -1, -2)
            cov_part = jnp.trace(identity + c - 2 * sqrt_inv_other @ scale_or_sqrt, axis1=-2, axis2=-1)
else:
            cov_part = jnp.trace(cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt, axis1=-2, axis2=-1)
return mean_part, cov_part
class WassersteinProjection(BaseProjection):
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
cov_bound: float = 0.01, scale_prec: bool = False,
contextual_std: bool = True, full_cov: bool = False):
assert not full_cov, "Full covariance is not supported for Wasserstein projection"
super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound,
cov_bound=cov_bound, contextual_std=contextual_std, full_cov=False)
self.scale_prec = scale_prec
def project(self, policy_params: Dict[str, jnp.ndarray],
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
mean = policy_params["loc"]
old_mean = old_policy_params["loc"]
scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"])
old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params["scale"])
mean_part, cov_part = gaussian_wasserstein_commutative(
(mean, scale_or_sqrt),
(old_mean, old_scale_or_sqrt),
self.scale_prec
)
proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_scale_or_sqrt = self._cov_projection(scale_or_sqrt, old_scale_or_sqrt, cov_part)
return {"loc": proj_mean, "scale": proj_scale_or_sqrt}
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
mean = policy_params["loc"]
proj_mean = proj_policy_params["loc"]
scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"])
proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params["scale"])
mean_part, cov_part = gaussian_wasserstein_commutative(
(mean, scale_or_sqrt),
(proj_mean, proj_scale_or_sqrt),
self.scale_prec
)
w2 = mean_part + cov_part
return w2.mean() * self.trust_region_coeff
def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
mean_part: jnp.ndarray) -> jnp.ndarray:
        diff = mean - old_mean
        # Expand the per-sample norm so the condition broadcasts against (batch, dim) means
        norm = jnp.sqrt(mean_part)[..., None]
        return jnp.where(norm > self.mean_bound,
                         old_mean + diff * self.mean_bound / norm,
                         mean)
def _cov_projection(self, scale_or_sqrt: jnp.ndarray, old_scale_or_sqrt: jnp.ndarray,
cov_part: jnp.ndarray) -> jnp.ndarray:
if scale_or_sqrt.ndim == old_scale_or_sqrt.ndim == 2: # Diagonal case
            diff = scale_or_sqrt - old_scale_or_sqrt
            # Expand the per-sample norm so the condition broadcasts against (batch, dim) scales
            norm = jnp.sqrt(cov_part)[..., None]
            return jnp.where(norm > self.cov_bound,
                             old_scale_or_sqrt + diff * self.cov_bound / norm,
                             scale_or_sqrt)
else: # Full covariance case
diff = scale_or_sqrt - old_scale_or_sqrt
norm = jnp.linalg.norm(diff, axis=(-2, -1), keepdims=True)
return jnp.where(norm > self.cov_bound,
old_scale_or_sqrt + diff * self.cov_bound / norm,
scale_or_sqrt)
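For diagonal Gaussians the commutative form reduces to the exact 2-Wasserstein distance: the squared distance between the means plus the squared distance between the standard-deviation vectors. A quick numeric check of the two terms (values made up):

import jax.numpy as jnp
from itpal_jax.wasserstein_projection import gaussian_wasserstein_commutative

mean, scale = jnp.array([[0.0, 0.0]]), jnp.array([[1.0, 1.0]])
old_mean, old_scale = jnp.array([[0.3, 0.0]]), jnp.array([[0.5, 1.0]])

mean_part, cov_part = gaussian_wasserstein_commutative((mean, scale), (old_mean, old_scale))
# mean_part ≈ [0.09] (= 0.3**2), cov_part ≈ [0.25] (= (1.0 - 0.5)**2)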