commit 9f85217a47
to git we go
itpal_jax/__init__.py (new file, 24 lines)
@@ -0,0 +1,24 @@
"""
JAX implementations of various policy projection methods.

Available projections:
- BaseProjection: Abstract base class for projections
- IdentityProjection: Simple identity mapping
- FrobeniusProjection: Frobenius norm-based projection
- WassersteinProjection: Wasserstein distance-based projection
- KLProjection: KL divergence-based projection (requires C++ backend)
"""

from .base_projection import BaseProjection
from .identity_projection import IdentityProjection
from .frobenius_projection import FrobeniusProjection
from .wasserstein_projection import WassersteinProjection
from .kl_projection import KLProjection

__all__ = [
    'BaseProjection',
    'IdentityProjection',
    'FrobeniusProjection',
    'WassersteinProjection',
    'KLProjection',
]
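All projections share the interface declared in base_projection.py below. A minimal usage sketch, assuming a single diagonal Gaussian parameterized by "loc" and "scale" (the parameter values are illustrative and not part of the commit):

# Illustrative only: the identity projection is a no-op baseline with the common interface.
import jax.numpy as jnp
from itpal_jax import IdentityProjection

proj = IdentityProjection(trust_region_coeff=1.0, mean_bound=0.01, cov_bound=0.01)

old_params = {"loc": jnp.zeros(2), "scale": jnp.ones(2)}
new_params = {"loc": jnp.array([0.2, -0.1]), "scale": jnp.array([1.5, 0.8])}

projected = proj.project(new_params, old_params)          # identity: returns new_params unchanged
loss = proj.get_trust_region_loss(new_params, projected)  # jnp.array(0.0) for the identity projection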
itpal_jax/base_projection.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod
from typing import Dict
import jax.numpy as jnp


class BaseProjection(ABC):
    def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
                 cov_bound: float = 0.01, contextual_std: bool = True, full_cov: bool = False):
        self.trust_region_coeff = trust_region_coeff
        self.mean_bound = mean_bound
        self.cov_bound = cov_bound
        self.full_cov = full_cov
        self.contextual_std = contextual_std

    @abstractmethod
    def project(self, policy_params: Dict[str, jnp.ndarray],
                old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        pass

    @abstractmethod
    def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
                              proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        pass

    def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        if not self.full_cov:
            return jnp.diag(params["scale"] ** 2)
        else:
            scale_tril = params["scale_tril"]
            return jnp.matmul(scale_tril, jnp.swapaxes(scale_tril, -1, -2))

    def _calc_scale_or_scale_tril(self, cov: jnp.ndarray) -> jnp.ndarray:
        if not self.full_cov:
            return jnp.sqrt(jnp.diagonal(cov, axis1=-2, axis2=-1))
        else:
            return jnp.linalg.cholesky(cov)
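A sketch of how this abstract base is meant to be extended. ClipMeanProjection is a hypothetical subclass written for illustration only (it is not in this commit): a concrete projection only has to implement project() and get_trust_region_loss(); the bounds, coefficient, and covariance helpers are inherited.

# Hypothetical subclass sketch, assuming the package layout added in this commit.
from typing import Dict
import jax.numpy as jnp
from itpal_jax.base_projection import BaseProjection


class ClipMeanProjection(BaseProjection):
    """Toy example: clip the mean change elementwise to +/- mean_bound, keep the scale."""

    def project(self, policy_params: Dict[str, jnp.ndarray],
                old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        delta = policy_params["loc"] - old_policy_params["loc"]
        clipped = jnp.clip(delta, -self.mean_bound, self.mean_bound)
        return {"loc": old_policy_params["loc"] + clipped, "scale": policy_params["scale"]}

    def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
                              proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        diff = policy_params["loc"] - proj_policy_params["loc"]
        return jnp.sum(jnp.square(diff)) * self.trust_region_coeff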
itpal_jax/exception_projection.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from .base_projection import BaseProjection
from functools import partial
import jax.numpy as jnp
from typing import Dict


class ExceptionProjection(BaseProjection):
    def __init__(self, msg, *args, **kwargs):
        self.msg = msg
        raise Exception(msg)

    def project(self, policy_params: Dict[str, jnp.ndarray],
                old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        pass

    def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
                              proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        pass


def makeExceptionProjection(msg: str):
    return partial(ExceptionProjection, msg)
itpal_jax/frobenius_projection.py (new file, 78 lines)
@@ -0,0 +1,78 @@
import jax.numpy as jnp
from .base_projection import BaseProjection
from typing import Dict


class FrobeniusProjection(BaseProjection):
    def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
                 cov_bound: float = 0.01, scale_prec: bool = False,
                 contextual_std: bool = True, full_cov: bool = False):
        super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound,
                         cov_bound=cov_bound, contextual_std=contextual_std, full_cov=full_cov)
        self.scale_prec = scale_prec

    def project(self, policy_params: Dict[str, jnp.ndarray],
                old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        mean = policy_params["loc"]
        old_mean = old_policy_params["loc"]

        cov = self._calc_covariance(policy_params)
        old_cov = self._calc_covariance(old_policy_params)

        mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))

        proj_mean = self._mean_projection(mean, old_mean, mean_part)
        proj_cov = self._cov_projection(cov, old_cov, cov_part)

        scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov)
        return {"loc": proj_mean, "scale": scale_or_scale_tril}

    def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
                              proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        mean = policy_params["loc"]
        proj_mean = proj_policy_params["loc"]

        cov = self._calc_covariance(policy_params)
        proj_cov = self._calc_covariance(proj_policy_params)

        mean_diff = jnp.sum(jnp.square(mean - proj_mean), axis=-1)
        cov_diff = jnp.sum(jnp.square(cov - proj_cov), axis=(-2, -1))

        return (mean_diff + cov_diff).mean() * self.trust_region_coeff

    def _gaussian_frobenius(self, p, q):
        mean, cov = p
        old_mean, old_cov = q

        if self.scale_prec:
            prec_old = jnp.linalg.inv(old_cov)
            mean_part = jnp.sum(jnp.matmul(mean - old_mean, prec_old) * (mean - old_mean), axis=-1)
            cov_part = jnp.sum(prec_old * cov, axis=(-2, -1)) - jnp.log(jnp.linalg.det(jnp.matmul(prec_old, cov))) - mean.shape[-1]
        else:
            mean_part = jnp.sum(jnp.square(mean - old_mean), axis=-1)
            cov_part = jnp.sum(jnp.square(cov - old_cov), axis=(-2, -1))

        return mean_part, cov_part

    def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
                         mean_part: jnp.ndarray) -> jnp.ndarray:
        diff = mean - old_mean
        norm = jnp.sqrt(mean_part)
        return jnp.where(norm > self.mean_bound,
                         old_mean + diff * self.mean_bound / norm[..., None],
                         mean)

    def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray,
                        cov_part: jnp.ndarray) -> jnp.ndarray:
        batch_shape = cov.shape[:-2]
        cov_mask = cov_part > self.cov_bound

        eta = jnp.ones(batch_shape, dtype=cov.dtype)
        eta = jnp.where(cov_mask,
                        jnp.sqrt(cov_part / self.cov_bound) - 1.,
                        eta)
        eta = jnp.maximum(-eta, eta)

        new_cov = (cov + jnp.einsum('i,ijk->ijk', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
        proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov)

        return proj_cov
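A standalone numeric sketch of the covariance step above (illustrative values, not part of the commit): when the squared Frobenius distance cov_part exceeds cov_bound, the new covariance is interpolated toward the old one with weight eta = sqrt(cov_part / cov_bound) - 1, which places the result back on the trust-region boundary.

# Worked example of the interpolation formula used in _cov_projection, unbatched for clarity.
import jax.numpy as jnp

cov_bound = 0.01
old_cov = jnp.eye(2)
cov = 1.5 * jnp.eye(2)

cov_part = jnp.sum(jnp.square(cov - old_cov))       # squared Frobenius distance = 0.5 > cov_bound
eta = jnp.sqrt(cov_part / cov_bound) - 1.0          # ~ 6.07, so the bound is violated
proj_cov = (cov + eta * old_cov) / (1.0 + eta)      # pulled toward old_cov

new_part = jnp.sum(jnp.square(proj_cov - old_cov))  # = cov_part / (1 + eta)^2 = cov_bound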
itpal_jax/identity_projection.py (new file, 17 lines)
@@ -0,0 +1,17 @@
import jax.numpy as jnp
from .base_projection import BaseProjection
from typing import Dict


class IdentityProjection(BaseProjection):
    def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
                 cov_bound: float = 0.01, contextual_std: bool = True, full_cov: bool = False):
        super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound,
                         cov_bound=cov_bound, contextual_std=contextual_std, full_cov=full_cov)

    def project(self, policy_params: Dict[str, jnp.ndarray],
                old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        return policy_params

    def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
                              proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        return jnp.array(0.0)
itpal_jax/kl_projection.py (new file, 244 lines)
@@ -0,0 +1,244 @@
try:
    import cpp_projection
    cpp_projection_available = True
except ImportError:
    cpp_projection_available = False
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import Dict, Tuple, Any
from .base_projection import BaseProjection
from .exception_projection import makeExceptionProjection

MAX_EVAL = 1000


class KLProjection(BaseProjection):
    """KL divergence-based projection for Gaussian policies.

    This class implements KL divergence projection using JAX, with a C++ backend
    for efficient projection operations. It supports both diagonal and full
    covariance matrices.

    Args:
        trust_region_coeff (float): Coefficient for trust region loss
        mean_bound (float): Bound for mean projection
        cov_bound (float): Bound for covariance projection
        contextual_std (bool): Whether to use contextual standard deviations
    """
    def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
                 cov_bound: float = 0.01, contextual_std: bool = True, full_cov: bool = False):
        super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound,
                         cov_bound=cov_bound, contextual_std=contextual_std, full_cov=full_cov)

    def project(self, policy_params: Dict[str, jnp.ndarray],
                old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        self._validate_inputs(policy_params, old_policy_params)
        mean, scale_or_tril = policy_params["loc"], policy_params["scale"]
        old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params["scale"]

        mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril),
                                                (old_mean, old_scale_or_tril))

        if not self.contextual_std:
            scale_or_tril = scale_or_tril[:1]
            old_scale_or_tril = old_scale_or_tril[:1]
            cov_part = cov_part[:1]

        proj_mean = self._mean_projection(mean, old_mean, mean_part)
        proj_scale_or_tril = self._cov_projection(scale_or_tril, old_scale_or_tril, cov_part)

        if not self.contextual_std:
            proj_scale_or_tril = jnp.broadcast_to(
                proj_scale_or_tril,
                (mean.shape[0],) + proj_scale_or_tril.shape[1:]
            )

        return {"loc": proj_mean, "scale": proj_scale_or_tril}

    def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
                              proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        mean, scale_or_tril = policy_params["loc"], policy_params["scale"]
        proj_mean, proj_scale_or_tril = proj_policy_params["loc"], proj_policy_params["scale"]
        kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
        return jnp.mean(kl) * self.trust_region_coeff

    def _gaussian_kl(self, p: Tuple[jnp.ndarray, jnp.ndarray],
                     q: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
        mean, scale_or_tril = p
        mean_other, scale_or_tril_other = q
        k = mean.shape[-1]

        maha_part = 0.5 * self._maha(mean, mean_other, scale_or_tril_other)

        det_term = self._log_determinant(scale_or_tril)
        det_term_other = self._log_determinant(scale_or_tril_other)

        if self.full_cov:
            trace_part = self._batched_trace_square(
                jax.scipy.linalg.solve_triangular(
                    scale_or_tril_other, scale_or_tril, lower=True
                )
            )
        else:
            trace_part = jnp.sum((scale_or_tril / scale_or_tril_other) ** 2, axis=-1)

        cov_part = 0.5 * (trace_part - k + det_term_other - det_term)

        return maha_part, cov_part

    def _maha(self, x: jnp.ndarray, y: jnp.ndarray, scale_or_tril: jnp.ndarray) -> jnp.ndarray:
        diff = x - y
        if self.full_cov:
            solved = jax.scipy.linalg.solve_triangular(
                scale_or_tril, diff[..., None], lower=True
            )
            return jnp.sum(jnp.square(solved.squeeze(-1)), axis=-1)
        else:
            return jnp.sum(jnp.square(diff / scale_or_tril), axis=-1)

    def _log_determinant(self, scale_or_tril: jnp.ndarray) -> jnp.ndarray:
        if self.full_cov:
            return 2 * jnp.sum(jnp.log(jnp.diagonal(scale_or_tril, axis1=-2, axis2=-1)), axis=-1)
        else:
            return 2 * jnp.sum(jnp.log(scale_or_tril), axis=-1)

    def _batched_trace_square(self, x: jnp.ndarray) -> jnp.ndarray:
        return jnp.sum(x ** 2, axis=(-2, -1))

    def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
                         mean_part: jnp.ndarray) -> jnp.ndarray:
        return old_mean + (mean - old_mean) * jnp.sqrt(
            self.mean_bound / (mean_part + 1e-8)
        )[..., None]

    def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray, cov_part: jnp.ndarray) -> jnp.ndarray:
        if self.full_cov:
            cov = jnp.matmul(scale_or_tril, jnp.swapaxes(scale_or_tril, -1, -2))
            old_cov = jnp.matmul(old_scale_or_tril, jnp.swapaxes(old_scale_or_tril, -1, -2))
        else:
            cov = scale_or_tril ** 2
            old_cov = old_scale_or_tril ** 2

        mask = cov_part > self.cov_bound
        proj_scale_or_tril = jnp.zeros_like(scale_or_tril)
        proj_scale_or_tril = jnp.where(~mask, scale_or_tril, proj_scale_or_tril)

        if mask.any():
            if self.full_cov:
                proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound)
                is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1))) & mask
                proj_scale_or_tril = jnp.where(is_invalid, old_scale_or_tril, proj_scale_or_tril)
                mask = mask & ~is_invalid
                chol = jnp.linalg.cholesky(proj_cov)
                proj_scale_or_tril = jnp.where(mask[..., None, None], chol, proj_scale_or_tril)
            else:
                proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound)
                is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) |
                              jnp.isinf(proj_cov.mean(axis=-1)) |
                              (proj_cov.min(axis=-1) < 0)) & mask
                proj_scale_or_tril = jnp.where(is_invalid, old_scale_or_tril, proj_scale_or_tril)
                mask = mask & ~is_invalid
                proj_scale_or_tril = jnp.where(mask[..., None], jnp.sqrt(proj_cov), proj_scale_or_tril)

        return proj_scale_or_tril

    def _validate_inputs(self, policy_params, old_policy_params):
        required_keys = ["loc", "scale"]
        for key in required_keys:
            if key not in policy_params or key not in old_policy_params:
                raise KeyError(f"Missing required key '{key}' in policy parameters")


@partial(jax.custom_vjp, nondiff_argnums=(3,))
def project_full_covariance(cov, chol, old_chol, eps_cov):
    """JAX wrapper for C++ full covariance projection"""
    try:
        # Convert JAX arrays to numpy for C++ function
        cov_np = np.asarray(cov)
        chol_np = np.asarray(chol)
        old_chol_np = np.asarray(old_chol)
        batch_shape = cov_np.shape[0]
        dim = cov_np.shape[-1]
        eps = eps_cov * np.ones(batch_shape)

        # Create C++ projection operator directly
        p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)

        # Run C++ projection
        proj_cov = p_op.forward(eps, old_chol_np, chol_np, cov_np)

        # Convert back to JAX array
        return jnp.array(proj_cov)
    except Exception as e:
        print(f"Full covariance projection failed: {e}")
        return old_chol  # Return old values on failure


def project_full_covariance_fwd(cov, chol, old_chol, eps_cov):
    y = project_full_covariance(cov, chol, old_chol, eps_cov)
    return y, (cov, chol, old_chol)


def project_full_covariance_bwd(eps_cov, res, g):
    cov, chol, old_chol = res

    # Convert to numpy for C++ backward pass
    g_np = np.asarray(g)
    batch_shape = g_np.shape[0]
    dim = g_np.shape[-1]

    # Get C++ projection operator
    p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)

    # Run C++ backward pass
    grad_cov = p_op.backward(g_np)

    # Convert back to JAX array
    return jnp.array(grad_cov), None, None


# Register VJP rule for full covariance projection
project_full_covariance.defvjp(project_full_covariance_fwd, project_full_covariance_bwd)


@partial(jax.custom_vjp, nondiff_argnums=(2,))
def project_diag_covariance(cov, old_cov, eps_cov):
    """JAX wrapper for C++ diagonal covariance projection"""
    # Convert JAX arrays to numpy for C++ function
    cov_np = np.asarray(cov)
    old_cov_np = np.asarray(old_cov)
    batch_shape = cov_np.shape[0]
    dim = cov_np.shape[-1]
    eps = eps_cov * np.ones(batch_shape)

    # Create C++ projection operator directly
    p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)

    # Run C++ projection
    proj_cov = p_op.forward(eps, old_cov_np, cov_np)

    # Convert back to JAX array
    return jnp.array(proj_cov)


def project_diag_covariance_fwd(cov, old_cov, eps_cov):
    y = project_diag_covariance(cov, old_cov, eps_cov)
    return y, (cov, old_cov)


def project_diag_covariance_bwd(eps_cov, res, g):
    cov, old_cov = res

    # Convert to numpy for C++ backward pass
    g_np = np.asarray(g)
    batch_shape = g_np.shape[0]
    dim = g_np.shape[-1]

    # Get C++ projection operator
    p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)

    # Run C++ backward pass
    grad_cov = p_op.backward(g_np)

    # Convert back to JAX array
    return jnp.array(grad_cov), None


# Register VJP rule for diagonal covariance projection
project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd)


if not cpp_projection_available:
    KLProjection = makeExceptionProjection("ITPAL (C++ library) is not available. Please install the C++ library to use this projection.")
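For reference, a standalone restatement of what _gaussian_kl computes in the diagonal case (not part of the commit; the numbers are illustrative): the mean term is half the Mahalanobis distance under the other distribution's covariance, the covariance term is the remaining trace/log-determinant part, and project() bounds the two terms separately rather than the total KL.

# Illustrative check of the diagonal-Gaussian KL decomposition mirrored by _gaussian_kl.
import jax.numpy as jnp

def diag_gaussian_kl(mean, scale, mean_other, scale_other):
    k = mean.shape[-1]
    maha_part = 0.5 * jnp.sum(jnp.square((mean - mean_other) / scale_other), axis=-1)
    trace_part = jnp.sum((scale / scale_other) ** 2, axis=-1)
    log_det = 2 * jnp.sum(jnp.log(scale), axis=-1)
    log_det_other = 2 * jnp.sum(jnp.log(scale_other), axis=-1)
    cov_part = 0.5 * (trace_part - k + log_det_other - log_det)
    return maha_part, cov_part

mean_part, cov_part = diag_gaussian_kl(jnp.array([[0.3, 0.0]]), jnp.array([[1.2, 0.9]]),
                                       jnp.array([[0.0, 0.0]]), jnp.array([[1.0, 1.0]]))
kl = mean_part + cov_part  # total KL(p || q) per batch element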
itpal_jax/wasserstein_projection.py (new file, 108 lines)
@@ -0,0 +1,108 @@
import jax.numpy as jnp
from .base_projection import BaseProjection
from typing import Dict, Tuple


def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
    """
    'Converts' scale_tril to scale_sqrt.

    For Wasserstein distance, we need the matrix square root, not the Cholesky decomposition.
    But since both are lower triangular, we can treat the Cholesky decomposition as if it were the matrix square root.
    """
    return scale_tril


def gaussian_wasserstein_commutative(p: Tuple[jnp.ndarray, jnp.ndarray],
                                     q: Tuple[jnp.ndarray, jnp.ndarray],
                                     scale_prec: bool = False) -> Tuple[jnp.ndarray, jnp.ndarray]:
    mean, scale_or_sqrt = p
    mean_other, scale_or_sqrt_other = q

    mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)

    if scale_or_sqrt.ndim == mean.ndim:  # Diagonal case
        cov = scale_or_sqrt ** 2
        cov_other = scale_or_sqrt_other ** 2
        if scale_prec:
            identity = jnp.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype)
            sqrt_inv_other = 1 / scale_or_sqrt_other
            c = sqrt_inv_other ** 2 * cov
            cov_part = jnp.sum(identity + c - 2 * sqrt_inv_other * scale_or_sqrt, axis=-1)
        else:
            cov_part = jnp.sum(cov_other + cov - 2 * scale_or_sqrt_other * scale_or_sqrt, axis=-1)
    else:  # Full covariance case
        # Note: scale_or_sqrt is treated as the matrix square root, not Cholesky decomposition
        cov = jnp.matmul(scale_or_sqrt, jnp.swapaxes(scale_or_sqrt, -1, -2))
        cov_other = jnp.matmul(scale_or_sqrt_other, jnp.swapaxes(scale_or_sqrt_other, -1, -2))
        if scale_prec:
            identity = jnp.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype)
            sqrt_inv_other = jnp.linalg.solve(scale_or_sqrt_other, identity)
            c = sqrt_inv_other @ cov @ jnp.swapaxes(sqrt_inv_other, -1, -2)
            cov_part = jnp.trace(identity + c - 2 * sqrt_inv_other @ scale_or_sqrt)
        else:
            cov_part = jnp.trace(cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt)

    return mean_part, cov_part


class WassersteinProjection(BaseProjection):
    def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
                 cov_bound: float = 0.01, scale_prec: bool = False,
                 contextual_std: bool = True, full_cov: bool = False):
        assert not full_cov, "Full covariance is not supported for Wasserstein projection"
        super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound,
                         cov_bound=cov_bound, contextual_std=contextual_std, full_cov=False)
        self.scale_prec = scale_prec

    def project(self, policy_params: Dict[str, jnp.ndarray],
                old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        mean = policy_params["loc"]
        old_mean = old_policy_params["loc"]
        scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"])
        old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params["scale"])

        mean_part, cov_part = gaussian_wasserstein_commutative(
            (mean, scale_or_sqrt),
            (old_mean, old_scale_or_sqrt),
            self.scale_prec
        )

        proj_mean = self._mean_projection(mean, old_mean, mean_part)
        proj_scale_or_sqrt = self._cov_projection(scale_or_sqrt, old_scale_or_sqrt, cov_part)

        return {"loc": proj_mean, "scale": proj_scale_or_sqrt}

    def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
                              proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        mean = policy_params["loc"]
        proj_mean = proj_policy_params["loc"]
        scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"])
        proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params["scale"])
        mean_part, cov_part = gaussian_wasserstein_commutative(
            (mean, scale_or_sqrt),
            (proj_mean, proj_scale_or_sqrt),
            self.scale_prec
        )
        w2 = mean_part + cov_part
        return w2.mean() * self.trust_region_coeff

    def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
                         mean_part: jnp.ndarray) -> jnp.ndarray:
        diff = mean - old_mean
        norm = jnp.sqrt(mean_part)
        return jnp.where(norm > self.mean_bound,
                         old_mean + diff * self.mean_bound / norm[..., None],
                         mean)

    def _cov_projection(self, scale_or_sqrt: jnp.ndarray, old_scale_or_sqrt: jnp.ndarray,
                        cov_part: jnp.ndarray) -> jnp.ndarray:
        if scale_or_sqrt.ndim == old_scale_or_sqrt.ndim == 2:  # Diagonal case
            diff = scale_or_sqrt - old_scale_or_sqrt
            norm = jnp.sqrt(cov_part)
            return jnp.where(norm > self.cov_bound,
                             old_scale_or_sqrt + diff * self.cov_bound / norm[..., None],
                             scale_or_sqrt)
        else:  # Full covariance case
            diff = scale_or_sqrt - old_scale_or_sqrt
            norm = jnp.linalg.norm(diff, axis=(-2, -1), keepdims=True)
            return jnp.where(norm > self.cov_bound,
                             old_scale_or_sqrt + diff * self.cov_bound / norm,
                             scale_or_sqrt)
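A short sanity-check sketch (illustrative values, not part of the commit): with scale_prec=False and diagonal scales, gaussian_wasserstein_commutative reduces to the squared 2-Wasserstein distance between commuting Gaussians, i.e. the squared mean difference plus the squared difference of the standard deviations.

# Illustrative check of the diagonal, scale_prec=False branch.
import jax.numpy as jnp
from itpal_jax.wasserstein_projection import gaussian_wasserstein_commutative

mean, scale = jnp.array([[0.5, -0.2]]), jnp.array([[1.3, 0.7]])
old_mean, old_scale = jnp.array([[0.0, 0.0]]), jnp.array([[1.0, 1.0]])

mean_part, cov_part = gaussian_wasserstein_commutative((mean, scale), (old_mean, old_scale))

# cov_part equals the squared difference of the standard deviations:
assert jnp.allclose(cov_part, jnp.sum(jnp.square(scale - old_scale), axis=-1))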