From e0eb46e14c1e65a1b94093a94ef4afcc54a3bf48 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 21 Dec 2024 17:48:53 +0100 Subject: [PATCH] Many fixes, that should have been multiple commits... --- itpal_jax/base_projection.py | 27 ++++--- itpal_jax/frobenius_projection.py | 36 ++++++++-- itpal_jax/kl_projection.py | 25 +++++-- itpal_jax/wasserstein_projection.py | 107 +++++++++++++--------------- 4 files changed, 118 insertions(+), 77 deletions(-) diff --git a/itpal_jax/base_projection.py b/itpal_jax/base_projection.py index 3ec32ea..7c28ac5 100644 --- a/itpal_jax/base_projection.py +++ b/itpal_jax/base_projection.py @@ -14,22 +14,29 @@ class BaseProjection(ABC): @abstractmethod def project(self, policy_params: Dict[str, jnp.ndarray], old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]: - pass - - @abstractmethod - def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], - proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: + """Project policy parameters. + + Args: + policy_params: Dictionary with: + - 'loc': mean parameters (batch_size, dim) + - 'scale': standard deviations (batch_size, dim) if full_cov=False + - 'scale_tril': Cholesky factor (batch_size, dim, dim) if full_cov=True + old_policy_params: Same format as policy_params + """ pass def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray: + """Convert scale representation to covariance matrix.""" if not self.full_cov: - return jnp.diag(params["scale"] ** 2) + scale = params["scale"] # standard deviations + return jnp.square(scale) # diagonal covariance else: - scale_tril = params["scale_tril"] + scale_tril = params["scale_tril"] # Cholesky factor return jnp.matmul(scale_tril, jnp.swapaxes(scale_tril, -1, -2)) - def _calc_scale_or_scale_tril(self, cov: jnp.ndarray) -> jnp.ndarray: + def _calc_scale_from_cov(self, cov: jnp.ndarray) -> jnp.ndarray: + """Convert covariance matrix back to appropriate scale representation.""" if not self.full_cov: - return jnp.sqrt(jnp.diagonal(cov, axis1=-2, axis2=-1)) + return jnp.sqrt(jnp.diagonal(cov, axis1=-2, axis2=-1)) # standard deviations else: - return jnp.linalg.cholesky(cov) \ No newline at end of file + return jnp.linalg.cholesky(cov) # Cholesky factor \ No newline at end of file diff --git a/itpal_jax/frobenius_projection.py b/itpal_jax/frobenius_projection.py index 3cd3cf7..2ec1d7e 100644 --- a/itpal_jax/frobenius_projection.py +++ b/itpal_jax/frobenius_projection.py @@ -15,16 +15,36 @@ class FrobeniusProjection(BaseProjection): mean = policy_params["loc"] old_mean = old_policy_params["loc"] + # Convert to covariance representation cov = self._calc_covariance(policy_params) old_cov = self._calc_covariance(old_policy_params) - mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov)) + if not self.contextual_std: + # Use only first batch element for scale, keeping batch dim + cov = cov[:1] # shape: (1, dim) + old_cov = old_cov[:1] # shape: (1, dim) + # Project in covariance space + mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov)) proj_mean = self._mean_projection(mean, old_mean, mean_part) proj_cov = self._cov_projection(cov, old_cov, cov_part) - scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov) - return {"loc": proj_mean, "scale": scale_or_scale_tril} + # Convert back to appropriate scale representation + scale_or_tril = self._calc_scale_from_cov(proj_cov) + + if not self.contextual_std: + # Broadcast scale to match original shape + if self.full_cov: + target_shape = policy_params["scale_tril"].shape + else: + target_shape = policy_params["scale"].shape + scale_or_tril = jnp.broadcast_to(scale_or_tril, target_shape) + + # Return with correct key + if self.full_cov: + return {"loc": proj_mean, "scale_tril": scale_or_tril} + else: + return {"loc": proj_mean, "scale": scale_or_tril} def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: @@ -72,7 +92,13 @@ class FrobeniusProjection(BaseProjection): eta) eta = jnp.maximum(-eta, eta) - new_cov = (cov + jnp.einsum('i,ijk->ijk', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None] - proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov) + if self.full_cov: + new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None] + else: + # For diagonal case, simple broadcasting + new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None] + + proj_cov = jnp.where(cov_mask[..., None] if not self.full_cov else cov_mask[..., None, None], + new_cov, cov) return proj_cov \ No newline at end of file diff --git a/itpal_jax/kl_projection.py b/itpal_jax/kl_projection.py index 17dd332..d75feef 100644 --- a/itpal_jax/kl_projection.py +++ b/itpal_jax/kl_projection.py @@ -34,8 +34,17 @@ class KLProjection(BaseProjection): def project(self, policy_params: Dict[str, jnp.ndarray], old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]: self._validate_inputs(policy_params, old_policy_params) - mean, scale_or_tril = policy_params["loc"], policy_params["scale"] - old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params["scale"] + + # Get the right scale parameter based on full_cov + mean = policy_params["loc"] + old_mean = old_policy_params["loc"] + + if self.full_cov: + scale_or_tril = policy_params["scale_tril"] + old_scale_or_tril = old_policy_params["scale_tril"] + else: + scale_or_tril = policy_params["scale"] + old_scale_or_tril = old_policy_params["scale"] mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril), (old_mean, old_scale_or_tril)) @@ -54,7 +63,10 @@ class KLProjection(BaseProjection): (mean.shape[0],) + proj_scale_or_tril.shape[1:] ) - return {"loc": proj_mean, "scale": proj_scale_or_tril} + if self.full_cov: + return {"loc": proj_mean, "scale_tril": proj_scale_or_tril} + else: + return {"loc": proj_mean, "scale": proj_scale_or_tril} def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: @@ -144,7 +156,12 @@ class KLProjection(BaseProjection): return proj_scale_or_tril def _validate_inputs(self, policy_params, old_policy_params): - required_keys = ["loc", "scale"] + """Validate input parameters have correct format.""" + if self.full_cov: + required_keys = ["loc", "scale_tril"] + else: + required_keys = ["loc", "scale"] + for key in required_keys: if key not in policy_params or key not in old_policy_params: raise KeyError(f"Missing required key '{key}' in policy parameters") diff --git a/itpal_jax/wasserstein_projection.py b/itpal_jax/wasserstein_projection.py index aa224cc..ecc5237 100644 --- a/itpal_jax/wasserstein_projection.py +++ b/itpal_jax/wasserstein_projection.py @@ -11,38 +11,6 @@ def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray: """ return scale_tril -def gaussian_wasserstein_commutative(p: Tuple[jnp.ndarray, jnp.ndarray], - q: Tuple[jnp.ndarray, jnp.ndarray], - scale_prec: bool = False) -> Tuple[jnp.ndarray, jnp.ndarray]: - mean, scale_or_sqrt = p - mean_other, scale_or_sqrt_other = q - - mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1) - - if scale_or_sqrt.ndim == mean.ndim: # Diagonal case - cov = scale_or_sqrt ** 2 - cov_other = scale_or_sqrt_other ** 2 - if scale_prec: - identity = jnp.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype) - sqrt_inv_other = 1 / scale_or_sqrt_other - c = sqrt_inv_other ** 2 * cov - cov_part = jnp.sum(identity + c - 2 * sqrt_inv_other * scale_or_sqrt, axis=-1) - else: - cov_part = jnp.sum(cov_other + cov - 2 * scale_or_sqrt_other * scale_or_sqrt, axis=-1) - else: # Full covariance case - # Note: scale_or_sqrt is treated as the matrix square root, not Cholesky decomposition - cov = jnp.matmul(scale_or_sqrt, jnp.swapaxes(scale_or_sqrt, -1, -2)) - cov_other = jnp.matmul(scale_or_sqrt_other, jnp.swapaxes(scale_or_sqrt_other, -1, -2)) - if scale_prec: - identity = jnp.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype) - sqrt_inv_other = jnp.linalg.solve(scale_or_sqrt_other, identity) - c = sqrt_inv_other @ cov @ jnp.swapaxes(sqrt_inv_other, -1, -2) - cov_part = jnp.trace(identity + c - 2 * sqrt_inv_other @ scale_or_sqrt) - else: - cov_part = jnp.trace(cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt) - - return mean_part, cov_part - class WassersteinProjection(BaseProjection): def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, @@ -54,21 +22,33 @@ class WassersteinProjection(BaseProjection): def project(self, policy_params: Dict[str, jnp.ndarray], old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]: - mean = policy_params["loc"] + assert not self.full_cov, "Wasserstein projection only supports diagonal covariance" + + mean = policy_params["loc"] # shape: (batch_size, dim) old_mean = old_policy_params["loc"] - scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"]) - old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params["scale"]) + scale = policy_params["scale"] # shape: (batch_size, dim) + old_scale = old_policy_params["scale"] - mean_part, cov_part = gaussian_wasserstein_commutative( - (mean, scale_or_sqrt), - (old_mean, old_scale_or_sqrt), - self.scale_prec + original_shape = scale.shape # Store original shape for broadcasting back + + if not self.contextual_std: + # Use only first batch element for scale + scale = scale[0] # shape: (dim,) + old_scale = old_scale[0] # shape: (dim,) + + mean_part, scale_part = self._gaussian_wasserstein( + (mean, scale), + (old_mean, old_scale) ) proj_mean = self._mean_projection(mean, old_mean, mean_part) - proj_scale_or_sqrt = self._cov_projection(scale_or_sqrt, old_scale_or_sqrt, cov_part) + proj_scale = self._scale_projection(scale, old_scale, scale_part) - return {"loc": proj_mean, "scale": proj_scale_or_sqrt} + if not self.contextual_std: + # Broadcast single scale to all batch elements + proj_scale = jnp.broadcast_to(proj_scale[None, :], original_shape) + + return {"loc": proj_mean, "scale": proj_scale} def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: @@ -76,10 +56,9 @@ class WassersteinProjection(BaseProjection): proj_mean = proj_policy_params["loc"] scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"]) proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params["scale"]) - mean_part, cov_part = gaussian_wasserstein_commutative( + mean_part, cov_part = self._gaussian_wasserstein( (mean, scale_or_sqrt), - (proj_mean, proj_scale_or_sqrt), - self.scale_prec + (proj_mean, proj_scale_or_sqrt) ) w2 = mean_part + cov_part return w2.mean() * self.trust_region_coeff @@ -92,17 +71,29 @@ class WassersteinProjection(BaseProjection): old_mean + diff * self.mean_bound / norm[..., None], mean) - def _cov_projection(self, scale_or_sqrt: jnp.ndarray, old_scale_or_sqrt: jnp.ndarray, - cov_part: jnp.ndarray) -> jnp.ndarray: - if scale_or_sqrt.ndim == old_scale_or_sqrt.ndim == 2: # Diagonal case - diff = scale_or_sqrt - old_scale_or_sqrt - norm = jnp.sqrt(cov_part) - return jnp.where(norm > self.cov_bound, - old_scale_or_sqrt + diff * self.cov_bound / norm[..., None], - scale_or_sqrt) - else: # Full covariance case - diff = scale_or_sqrt - old_scale_or_sqrt - norm = jnp.linalg.norm(diff, axis=(-2, -1), keepdims=True) - return jnp.where(norm > self.cov_bound, - old_scale_or_sqrt + diff * self.cov_bound / norm, - scale_or_sqrt) \ No newline at end of file + def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray, + scale_part: jnp.ndarray) -> jnp.ndarray: + """Project scale parameters (standard deviations for diagonal case)""" + diff = scale - old_scale + norm = jnp.sqrt(scale_part) + + if scale.ndim == 2: # Batched scale + norm = norm[..., None] + + return jnp.where(norm > self.cov_bound, + old_scale + diff * self.cov_bound / norm, + scale) + + def _gaussian_wasserstein(self, p, q): + mean, scale = p + mean_other, scale_other = q + + # Keep batch dimension by only summing over feature dimension + mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1) # -> (batch_size,) + + if scale.ndim == mean.ndim: # Batched scale + cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale, axis=-1) + else: # Non-contextual scale (single scale for all batches) + cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale) + + return mean_part, cov_part \ No newline at end of file