# itpal_jax/itpal_jax/wasserstein_projection.py
# (repository viewer metadata: 99 lines, 4.3 KiB, Python)

import jax.numpy as jnp
from .base_projection import BaseProjection
from typing import Dict, Tuple
def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
    """Return ``scale_tril`` reinterpreted as a covariance square root.

    The Wasserstein distance is expressed in terms of the matrix square root
    of the covariance rather than its Cholesky factor. No conversion is
    performed here: the argument is handed back unchanged.

    For a diagonal covariance the two factors coincide (both are the
    per-dimension standard deviations), so the pass-through is exact.
    NOTE(review): for a full covariance the Cholesky factor and the symmetric
    square root generally differ, so this would only be an approximation
    there -- confirm before using it with non-diagonal scales.

    Args:
        scale_tril: Scale parameter of the Gaussian.

    Returns:
        The very same array that was passed in.
    """
    scale_sqrt = scale_tril  # identity: nothing to convert in the diagonal case
    return scale_sqrt
class WassersteinProjection(BaseProjection):
    """Trust-region projection using a Wasserstein-style distance between diagonal Gaussians.

    The distance is split into a mean part (squared Euclidean distance between
    the means) and a scale part (sum of squared differences between the
    per-dimension scales). Each part is bounded separately: whenever the
    square root of a part exceeds its bound, the corresponding parameter is
    pulled back along the straight line towards the old parameter until the
    distance equals the bound.

    Only diagonal covariances are handled; ``full_cov=True`` is rejected.

    NOTE(review): ``self.mean_bound``, ``self.cov_bound``,
    ``self.trust_region_coeff``, ``self.contextual_std`` and ``self.full_cov``
    are read below and are presumably set by ``BaseProjection.__init__`` --
    confirm against the base class.
    """

    def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
                 cov_bound: float = 0.01, scale_prec: bool = False,
                 contextual_std: bool = True, full_cov: bool = False):
        """Configure the projection.

        Args:
            trust_region_coeff: Weight applied to the trust-region loss.
            mean_bound: Maximum allowed distance between new and old mean.
            cov_bound: Maximum allowed distance between new and old scale.
            scale_prec: Stored on the instance; not used anywhere in this class.
            contextual_std: If False, a single scale (taken from the first
                batch element) is projected and shared by the whole batch.
            full_cov: Must be False; full covariances are not supported.

        Raises:
            AssertionError: If ``full_cov`` is True.
        """
        assert not full_cov, "Full covariance is not supported for Wasserstein projection"
        super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound,
                         cov_bound=cov_bound, contextual_std=contextual_std, full_cov=False)
        self.scale_prec = scale_prec

    def project(self, policy_params: Dict[str, jnp.ndarray],
                old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        """Project the new policy parameters into the trust region around the old ones.

        Args:
            policy_params: Dict with ``"loc"`` and ``"scale"`` entries of the
                new policy, each expected to be of shape (batch_size, dim).
            old_policy_params: Same layout, for the old policy.

        Returns:
            Dict with the projected ``"loc"`` and ``"scale"``; ``"scale"`` has
            the same shape as the incoming ``policy_params["scale"]``.

        Raises:
            AssertionError: If the instance is configured for full covariance.
        """
        assert not self.full_cov, "Wasserstein projection only supports diagonal covariance"

        mean = policy_params["loc"]  # shape: (batch_size, dim)
        old_mean = old_policy_params["loc"]
        scale = policy_params["scale"]  # shape: (batch_size, dim)
        old_scale = old_policy_params["scale"]
        original_shape = scale.shape  # needed to broadcast a shared scale back

        if not self.contextual_std:
            # The scale is shared across the batch: project one copy only.
            scale = scale[0]  # shape: (dim,)
            old_scale = old_scale[0]  # shape: (dim,)

        mean_part, scale_part = self._gaussian_wasserstein(
            (mean, scale),
            (old_mean, old_scale),
        )

        proj_mean = self._mean_projection(mean, old_mean, mean_part)
        proj_scale = self._scale_projection(scale, old_scale, scale_part)

        if not self.contextual_std:
            # Hand the single projected scale back to every batch element.
            proj_scale = jnp.broadcast_to(proj_scale[None, :], original_shape)

        return {"loc": proj_mean, "scale": proj_scale}

    def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
                              proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Return the weighted distance between the raw and the projected policy.

        Args:
            policy_params: Dict with ``"loc"`` and ``"scale"`` of the
                unprojected policy.
            proj_policy_params: Same layout, for the projected policy.

        Returns:
            Mean over the batch of (mean part + scale part), multiplied by
            ``self.trust_region_coeff``.
        """
        mean = policy_params["loc"]
        proj_mean = proj_policy_params["loc"]
        scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"])
        proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params["scale"])

        mean_part, cov_part = self._gaussian_wasserstein(
            (mean, scale_or_sqrt),
            (proj_mean, proj_scale_or_sqrt),
        )
        w2 = mean_part + cov_part
        return w2.mean() * self.trust_region_coeff

    @staticmethod
    def _rescale_towards_old(new: jnp.ndarray, old: jnp.ndarray,
                             sq_dist: jnp.ndarray, bound: float) -> jnp.ndarray:
        """Shrink ``new - old`` so that its length is ``bound`` wherever it exceeds it.

        ``sq_dist`` holds the squared distance with the feature axis already
        reduced; a trailing axis is re-added here so that the mask and the
        denominator broadcast against ``new`` / ``old`` along the batch axes
        rather than along the feature axis.
        """
        sq_dist = sq_dist[..., None]
        exceeds = jnp.sqrt(sq_dist) > bound
        # Replace the distance by 1 wherever no rescaling happens *before*
        # taking the square root and dividing. jnp.where evaluates both
        # branches, and a zero distance would otherwise produce 0/0 in the
        # discarded branch, which turns the gradients into NaN.
        safe_dist = jnp.sqrt(jnp.where(exceeds, sq_dist, 1.0))
        return jnp.where(exceeds, old + (new - old) * bound / safe_dist, new)

    def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
                         mean_part: jnp.ndarray) -> jnp.ndarray:
        """Project the mean so that its distance to ``old_mean`` is at most ``mean_bound``.

        Args:
            mean: New mean, shape (..., dim).
            old_mean: Old mean, same shape as ``mean``.
            mean_part: Squared distance between the means, shape (...,).

        Returns:
            The projected mean, same shape as ``mean``.
        """
        return self._rescale_towards_old(mean, old_mean, mean_part, self.mean_bound)

    def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray,
                          scale_part: jnp.ndarray) -> jnp.ndarray:
        """Project scale parameters (standard deviations for diagonal case).

        Args:
            scale: New scale, shape (..., dim) or (dim,) for a shared scale.
            old_scale: Old scale, same shape as ``scale``.
            scale_part: Squared distance between the scales with the feature
                axis reduced (a scalar for a shared scale).

        Returns:
            The projected scale, same shape as ``scale``.
        """
        return self._rescale_towards_old(scale, old_scale, scale_part, self.cov_bound)

    def _gaussian_wasserstein(self, p: Tuple[jnp.ndarray, jnp.ndarray],
                              q: Tuple[jnp.ndarray, jnp.ndarray]
                              ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Compute the mean and scale parts of the squared distance between two Gaussians.

        Args:
            p: ``(mean, scale)`` of the first distribution.
            q: ``(mean, scale)`` of the second distribution.

        Returns:
            ``(mean_part, cov_part)``. ``mean_part`` is the squared Euclidean
            distance between the means, reduced over the last axis.
            ``cov_part`` is the sum of squared scale differences, reduced over
            the last axis when the scale has as many axes as the mean, and
            over all axes otherwise (shared, non-contextual scale).
        """
        mean, scale = p
        mean_other, scale_other = q

        # Keep batch dimension by only summing over feature dimension.
        mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)  # -> (batch_size,)

        # (a - b)^2 instead of a^2 + b^2 - 2ab: same value mathematically, but
        # free of cancellation and never negative, so a later sqrt stays finite.
        sq_scale_diff = jnp.square(scale_other - scale)
        reduce_axis = -1 if scale.ndim == mean.ndim else None
        cov_part = jnp.sum(sq_scale_diff, axis=reduce_axis)

        return mean_part, cov_part