99 lines
4.3 KiB
Python
99 lines
4.3 KiB
Python
import jax.numpy as jnp
|
|
from .base_projection import BaseProjection
|
|
from typing import Dict, Tuple
|
|
|
|
def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
    """Return ``scale_tril`` unchanged so it can stand in for a covariance square root.

    The Wasserstein distance is defined in terms of the matrix square root of the
    covariance rather than its Cholesky factor. This helper performs no computation:
    it hands the input straight back and exists only to mark, at the call site, that
    the factor is being used in the role of a square root.

    NOTE(review): the substitution is exact when the covariance is diagonal (the
    Cholesky factor and the square root are then the same diagonal matrix of standard
    deviations). For a general lower-triangular factor the two differ -- confirm
    before reusing this helper with full covariances.

    Args:
        scale_tril: Scale factor of the covariance.

    Returns:
        The very same object that was passed in.
    """
    scale_sqrt = scale_tril  # identity: no conversion is actually carried out
    return scale_sqrt
|
|
|
|
class WassersteinProjection(BaseProjection):
    """Trust-region projection for diagonal Gaussians using a Wasserstein-style split.

    The squared distance between two diagonal Gaussians is decomposed into a mean
    part (squared Euclidean distance between the locations) and a scale part
    (squared Euclidean distance between the per-dimension scales). Each part is
    limited independently: the mean is pulled back towards the old mean when its
    distance exceeds ``mean_bound``, and the scale is pulled back towards the old
    scale when its distance exceeds ``cov_bound``.

    Only diagonal covariances are supported.

    NOTE(review): ``mean_bound``, ``cov_bound``, ``trust_region_coeff``,
    ``contextual_std`` and ``full_cov`` are read from ``self`` but are not assigned
    in this class; presumably ``BaseProjection.__init__`` stores them -- confirm.
    """

    def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
                 cov_bound: float = 0.01, scale_prec: bool = False,
                 contextual_std: bool = True, full_cov: bool = False):
        """Configure the projection.

        Args:
            trust_region_coeff: Multiplier applied to the trust-region loss.
            mean_bound: Maximum allowed Euclidean distance between new and old mean.
            cov_bound: Maximum allowed Euclidean distance between new and old scale.
            scale_prec: Stored on the instance; not read anywhere in this class.
            contextual_std: If False, a single scale (taken from the first batch
                element) is projected and shared across the whole batch.
            full_cov: Must be False; full covariances are rejected.
        """
        assert not full_cov, "Full covariance is not supported for Wasserstein projection"
        super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound,
                         cov_bound=cov_bound, contextual_std=contextual_std, full_cov=False)
        self.scale_prec = scale_prec

    def project(self, policy_params: Dict[str, jnp.ndarray],
                old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        """Project the new policy parameters into the trust region around the old ones.

        Args:
            policy_params: Dict with ``"loc"`` and ``"scale"`` of the new policy,
                each of shape ``(batch_size, dim)``.
            old_policy_params: Same layout, for the old policy.

        Returns:
            Dict with the projected ``"loc"`` and ``"scale"``; the scale has the same
            shape as the incoming ``policy_params["scale"]``.
        """
        assert not self.full_cov, "Wasserstein projection only supports diagonal covariance"

        mean = policy_params["loc"]  # shape: (batch_size, dim)
        old_mean = old_policy_params["loc"]
        scale = policy_params["scale"]  # shape: (batch_size, dim)
        old_scale = old_policy_params["scale"]

        original_shape = scale.shape  # kept so the shared scale can be broadcast back

        if not self.contextual_std:
            # The scale does not depend on the context: use the first batch element only.
            scale = scale[0]  # shape: (dim,)
            old_scale = old_scale[0]  # shape: (dim,)

        mean_part, scale_part = self._gaussian_wasserstein(
            (mean, scale),
            (old_mean, old_scale)
        )

        proj_mean = self._mean_projection(mean, old_mean, mean_part)
        proj_scale = self._scale_projection(scale, old_scale, scale_part)

        if not self.contextual_std:
            # Broadcast the single projected scale to all batch elements.
            proj_scale = jnp.broadcast_to(proj_scale[None, :], original_shape)

        return {"loc": proj_mean, "scale": proj_scale}

    def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
                              proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Return the regression loss between the raw and the projected policy.

        The loss is the mean of ``mean_part + cov_part`` (the squared distance, no
        square root is taken) multiplied by ``trust_region_coeff``.

        Args:
            policy_params: Dict with ``"loc"`` and ``"scale"`` of the unprojected policy.
            proj_policy_params: Dict with ``"loc"`` and ``"scale"`` of the projected policy.

        Returns:
            Scalar loss.
        """
        mean = policy_params["loc"]
        proj_mean = proj_policy_params["loc"]
        scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"])
        proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params["scale"])
        mean_part, cov_part = self._gaussian_wasserstein(
            (mean, scale_or_sqrt),
            (proj_mean, proj_scale_or_sqrt)
        )
        w2 = mean_part + cov_part
        return w2.mean() * self.trust_region_coeff

    @staticmethod
    def _safe_sqrt(sq_dist: jnp.ndarray) -> jnp.ndarray:
        """Square root that yields 0 (with a zero gradient) for non-positive input.

        ``jnp.sqrt`` has an infinite derivative at 0, which turns into NaN gradients
        once it is combined with ``jnp.where``. Evaluating the root on a dummy value
        of 1 wherever the input is not strictly positive keeps the backward pass
        finite while leaving the forward value for positive inputs untouched.
        """
        positive = sq_dist > 0
        return jnp.where(positive, jnp.sqrt(jnp.where(positive, sq_dist, 1.0)), 0.0)

    def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
                         mean_part: jnp.ndarray) -> jnp.ndarray:
        """Pull ``mean`` back towards ``old_mean`` when it left the mean trust region.

        Args:
            mean: New mean, shape ``(..., dim)``.
            old_mean: Old mean, same shape as ``mean``.
            mean_part: Squared Euclidean distance between the two means with the
                feature axis reduced away, shape ``(...)``.

        Returns:
            ``mean`` where the distance is within ``mean_bound``; otherwise the point
            on the segment from ``old_mean`` to ``mean`` at distance ``mean_bound``.
        """
        diff = mean - old_mean
        # Restore the reduced feature axis so the condition broadcasts row-wise
        # against the (..., dim) operands instead of along the feature axis.
        norm = self._safe_sqrt(mean_part)[..., None]
        exceeds = norm > self.mean_bound
        # Divide by 1 in the branch that is discarded anyway, so a zero norm never
        # produces NaN in the forward or backward pass.
        safe_norm = jnp.where(exceeds, norm, 1.0)
        return jnp.where(exceeds,
                         old_mean + diff * self.mean_bound / safe_norm,
                         mean)

    def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray,
                          scale_part: jnp.ndarray) -> jnp.ndarray:
        """Project scale parameters (standard deviations for diagonal case).

        Args:
            scale: New scale, shape ``(..., dim)`` or ``(dim,)``.
            old_scale: Old scale, same shape as ``scale``.
            scale_part: Squared Euclidean distance between the two scales, either
                with the feature axis reduced away or a single scalar.

        Returns:
            ``scale`` where the distance is within ``cov_bound``; otherwise the point
            on the segment from ``old_scale`` to ``scale`` at distance ``cov_bound``.
        """
        diff = scale - old_scale
        norm = self._safe_sqrt(scale_part)

        if norm.ndim + 1 == scale.ndim:
            # The feature axis was reduced away: add it back for broadcasting.
            norm = norm[..., None]

        exceeds = norm > self.cov_bound
        # Same guard as for the mean: never divide by a zero norm.
        safe_norm = jnp.where(exceeds, norm, 1.0)
        return jnp.where(exceeds,
                         old_scale + diff * self.cov_bound / safe_norm,
                         scale)

    def _gaussian_wasserstein(
            self,
            p: Tuple[jnp.ndarray, jnp.ndarray],
            q: Tuple[jnp.ndarray, jnp.ndarray],
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Return the mean and scale parts of the squared distance between ``p`` and ``q``.

        Args:
            p: ``(mean, scale)`` of the first Gaussian.
            q: ``(mean, scale)`` of the second Gaussian.

        Returns:
            ``(mean_part, cov_part)``. ``mean_part`` is the squared distance between
            the means summed over the last axis. ``cov_part`` is the squared distance
            between the scales, summed over the last axis when the scale has as many
            axes as the mean, and over all axes otherwise.
        """
        mean, scale = p
        mean_other, scale_other = q

        # Keep batch dimension by only summing over feature dimension.
        mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)  # -> (batch_size,)

        # (a - b)^2 equals a^2 + b^2 - 2ab but can never round below zero, so a
        # later square root of the sum stays well defined.
        sq_scale_diff = jnp.square(scale - scale_other)

        if scale.ndim == mean.ndim:  # Batched scale
            cov_part = jnp.sum(sq_scale_diff, axis=-1)
        else:  # Non-contextual scale (single scale for all batches)
            cov_part = jnp.sum(sq_scale_diff)

        return mean_part, cov_part