Many fixes, that should have been multiple commits...

This commit is contained in:
Dominik Moritz Roth 2024-12-21 17:48:53 +01:00
parent e414d8c5b2
commit e0eb46e14c
4 changed files with 118 additions and 77 deletions

View File

@ -14,22 +14,29 @@ class BaseProjection(ABC):
@abstractmethod @abstractmethod
def project(self, policy_params: Dict[str, jnp.ndarray], def project(self, policy_params: Dict[str, jnp.ndarray],
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]: old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
pass """Project policy parameters.
@abstractmethod Args:
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], policy_params: Dictionary with:
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: - 'loc': mean parameters (batch_size, dim)
- 'scale': standard deviations (batch_size, dim) if full_cov=False
- 'scale_tril': Cholesky factor (batch_size, dim, dim) if full_cov=True
old_policy_params: Same format as policy_params
"""
pass pass
def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray: def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""Convert scale representation to covariance matrix."""
if not self.full_cov: if not self.full_cov:
return jnp.diag(params["scale"] ** 2) scale = params["scale"] # standard deviations
return jnp.square(scale) # diagonal covariance
else: else:
scale_tril = params["scale_tril"] scale_tril = params["scale_tril"] # Cholesky factor
return jnp.matmul(scale_tril, jnp.swapaxes(scale_tril, -1, -2)) return jnp.matmul(scale_tril, jnp.swapaxes(scale_tril, -1, -2))
def _calc_scale_or_scale_tril(self, cov: jnp.ndarray) -> jnp.ndarray: def _calc_scale_from_cov(self, cov: jnp.ndarray) -> jnp.ndarray:
"""Convert covariance matrix back to appropriate scale representation."""
if not self.full_cov: if not self.full_cov:
return jnp.sqrt(jnp.diagonal(cov, axis1=-2, axis2=-1)) return jnp.sqrt(jnp.diagonal(cov, axis1=-2, axis2=-1)) # standard deviations
else: else:
return jnp.linalg.cholesky(cov) return jnp.linalg.cholesky(cov) # Cholesky factor

View File

@ -15,16 +15,36 @@ class FrobeniusProjection(BaseProjection):
mean = policy_params["loc"] mean = policy_params["loc"]
old_mean = old_policy_params["loc"] old_mean = old_policy_params["loc"]
# Convert to covariance representation
cov = self._calc_covariance(policy_params) cov = self._calc_covariance(policy_params)
old_cov = self._calc_covariance(old_policy_params) old_cov = self._calc_covariance(old_policy_params)
mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov)) if not self.contextual_std:
# Use only first batch element for scale, keeping batch dim
cov = cov[:1] # shape: (1, dim)
old_cov = old_cov[:1] # shape: (1, dim)
# Project in covariance space
mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
proj_mean = self._mean_projection(mean, old_mean, mean_part) proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_cov = self._cov_projection(cov, old_cov, cov_part) proj_cov = self._cov_projection(cov, old_cov, cov_part)
scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov) # Convert back to appropriate scale representation
return {"loc": proj_mean, "scale": scale_or_scale_tril} scale_or_tril = self._calc_scale_from_cov(proj_cov)
if not self.contextual_std:
# Broadcast scale to match original shape
if self.full_cov:
target_shape = policy_params["scale_tril"].shape
else:
target_shape = policy_params["scale"].shape
scale_or_tril = jnp.broadcast_to(scale_or_tril, target_shape)
# Return with correct key
if self.full_cov:
return {"loc": proj_mean, "scale_tril": scale_or_tril}
else:
return {"loc": proj_mean, "scale": scale_or_tril}
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
@ -72,7 +92,13 @@ class FrobeniusProjection(BaseProjection):
eta) eta)
eta = jnp.maximum(-eta, eta) eta = jnp.maximum(-eta, eta)
new_cov = (cov + jnp.einsum('i,ijk->ijk', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None] if self.full_cov:
proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov) new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
else:
# For diagonal case, simple broadcasting
new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
proj_cov = jnp.where(cov_mask[..., None] if not self.full_cov else cov_mask[..., None, None],
new_cov, cov)
return proj_cov return proj_cov

View File

@ -34,8 +34,17 @@ class KLProjection(BaseProjection):
def project(self, policy_params: Dict[str, jnp.ndarray], def project(self, policy_params: Dict[str, jnp.ndarray],
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]: old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
self._validate_inputs(policy_params, old_policy_params) self._validate_inputs(policy_params, old_policy_params)
mean, scale_or_tril = policy_params["loc"], policy_params["scale"]
old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params["scale"] # Get the right scale parameter based on full_cov
mean = policy_params["loc"]
old_mean = old_policy_params["loc"]
if self.full_cov:
scale_or_tril = policy_params["scale_tril"]
old_scale_or_tril = old_policy_params["scale_tril"]
else:
scale_or_tril = policy_params["scale"]
old_scale_or_tril = old_policy_params["scale"]
mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril), mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril),
(old_mean, old_scale_or_tril)) (old_mean, old_scale_or_tril))
@ -54,6 +63,9 @@ class KLProjection(BaseProjection):
(mean.shape[0],) + proj_scale_or_tril.shape[1:] (mean.shape[0],) + proj_scale_or_tril.shape[1:]
) )
if self.full_cov:
return {"loc": proj_mean, "scale_tril": proj_scale_or_tril}
else:
return {"loc": proj_mean, "scale": proj_scale_or_tril} return {"loc": proj_mean, "scale": proj_scale_or_tril}
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
@ -144,7 +156,12 @@ class KLProjection(BaseProjection):
return proj_scale_or_tril return proj_scale_or_tril
def _validate_inputs(self, policy_params, old_policy_params): def _validate_inputs(self, policy_params, old_policy_params):
"""Validate input parameters have correct format."""
if self.full_cov:
required_keys = ["loc", "scale_tril"]
else:
required_keys = ["loc", "scale"] required_keys = ["loc", "scale"]
for key in required_keys: for key in required_keys:
if key not in policy_params or key not in old_policy_params: if key not in policy_params or key not in old_policy_params:
raise KeyError(f"Missing required key '{key}' in policy parameters") raise KeyError(f"Missing required key '{key}' in policy parameters")

View File

@ -11,38 +11,6 @@ def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
""" """
return scale_tril return scale_tril
def gaussian_wasserstein_commutative(p: Tuple[jnp.ndarray, jnp.ndarray],
q: Tuple[jnp.ndarray, jnp.ndarray],
scale_prec: bool = False) -> Tuple[jnp.ndarray, jnp.ndarray]:
mean, scale_or_sqrt = p
mean_other, scale_or_sqrt_other = q
mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)
if scale_or_sqrt.ndim == mean.ndim: # Diagonal case
cov = scale_or_sqrt ** 2
cov_other = scale_or_sqrt_other ** 2
if scale_prec:
identity = jnp.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype)
sqrt_inv_other = 1 / scale_or_sqrt_other
c = sqrt_inv_other ** 2 * cov
cov_part = jnp.sum(identity + c - 2 * sqrt_inv_other * scale_or_sqrt, axis=-1)
else:
cov_part = jnp.sum(cov_other + cov - 2 * scale_or_sqrt_other * scale_or_sqrt, axis=-1)
else: # Full covariance case
# Note: scale_or_sqrt is treated as the matrix square root, not Cholesky decomposition
cov = jnp.matmul(scale_or_sqrt, jnp.swapaxes(scale_or_sqrt, -1, -2))
cov_other = jnp.matmul(scale_or_sqrt_other, jnp.swapaxes(scale_or_sqrt_other, -1, -2))
if scale_prec:
identity = jnp.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype)
sqrt_inv_other = jnp.linalg.solve(scale_or_sqrt_other, identity)
c = sqrt_inv_other @ cov @ jnp.swapaxes(sqrt_inv_other, -1, -2)
cov_part = jnp.trace(identity + c - 2 * sqrt_inv_other @ scale_or_sqrt)
else:
cov_part = jnp.trace(cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt)
return mean_part, cov_part
class WassersteinProjection(BaseProjection): class WassersteinProjection(BaseProjection):
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01, def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
cov_bound: float = 0.01, scale_prec: bool = False, cov_bound: float = 0.01, scale_prec: bool = False,
@ -54,21 +22,33 @@ class WassersteinProjection(BaseProjection):
def project(self, policy_params: Dict[str, jnp.ndarray], def project(self, policy_params: Dict[str, jnp.ndarray],
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]: old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
mean = policy_params["loc"] assert not self.full_cov, "Wasserstein projection only supports diagonal covariance"
old_mean = old_policy_params["loc"]
scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"])
old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params["scale"])
mean_part, cov_part = gaussian_wasserstein_commutative( mean = policy_params["loc"] # shape: (batch_size, dim)
(mean, scale_or_sqrt), old_mean = old_policy_params["loc"]
(old_mean, old_scale_or_sqrt), scale = policy_params["scale"] # shape: (batch_size, dim)
self.scale_prec old_scale = old_policy_params["scale"]
original_shape = scale.shape # Store original shape for broadcasting back
if not self.contextual_std:
# Use only first batch element for scale
scale = scale[0] # shape: (dim,)
old_scale = old_scale[0] # shape: (dim,)
mean_part, scale_part = self._gaussian_wasserstein(
(mean, scale),
(old_mean, old_scale)
) )
proj_mean = self._mean_projection(mean, old_mean, mean_part) proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_scale_or_sqrt = self._cov_projection(scale_or_sqrt, old_scale_or_sqrt, cov_part) proj_scale = self._scale_projection(scale, old_scale, scale_part)
return {"loc": proj_mean, "scale": proj_scale_or_sqrt} if not self.contextual_std:
# Broadcast single scale to all batch elements
proj_scale = jnp.broadcast_to(proj_scale[None, :], original_shape)
return {"loc": proj_mean, "scale": proj_scale}
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray], def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray: proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
@ -76,10 +56,9 @@ class WassersteinProjection(BaseProjection):
proj_mean = proj_policy_params["loc"] proj_mean = proj_policy_params["loc"]
scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"]) scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"])
proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params["scale"]) proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params["scale"])
mean_part, cov_part = gaussian_wasserstein_commutative( mean_part, cov_part = self._gaussian_wasserstein(
(mean, scale_or_sqrt), (mean, scale_or_sqrt),
(proj_mean, proj_scale_or_sqrt), (proj_mean, proj_scale_or_sqrt)
self.scale_prec
) )
w2 = mean_part + cov_part w2 = mean_part + cov_part
return w2.mean() * self.trust_region_coeff return w2.mean() * self.trust_region_coeff
@ -92,17 +71,29 @@ class WassersteinProjection(BaseProjection):
old_mean + diff * self.mean_bound / norm[..., None], old_mean + diff * self.mean_bound / norm[..., None],
mean) mean)
def _cov_projection(self, scale_or_sqrt: jnp.ndarray, old_scale_or_sqrt: jnp.ndarray, def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray,
cov_part: jnp.ndarray) -> jnp.ndarray: scale_part: jnp.ndarray) -> jnp.ndarray:
if scale_or_sqrt.ndim == old_scale_or_sqrt.ndim == 2: # Diagonal case """Project scale parameters (standard deviations for diagonal case)"""
diff = scale_or_sqrt - old_scale_or_sqrt diff = scale - old_scale
norm = jnp.sqrt(cov_part) norm = jnp.sqrt(scale_part)
if scale.ndim == 2: # Batched scale
norm = norm[..., None]
return jnp.where(norm > self.cov_bound, return jnp.where(norm > self.cov_bound,
old_scale_or_sqrt + diff * self.cov_bound / norm[..., None], old_scale + diff * self.cov_bound / norm,
scale_or_sqrt) scale)
else: # Full covariance case
diff = scale_or_sqrt - old_scale_or_sqrt def _gaussian_wasserstein(self, p, q):
norm = jnp.linalg.norm(diff, axis=(-2, -1), keepdims=True) mean, scale = p
return jnp.where(norm > self.cov_bound, mean_other, scale_other = q
old_scale_or_sqrt + diff * self.cov_bound / norm,
scale_or_sqrt) # Keep batch dimension by only summing over feature dimension
mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1) # -> (batch_size,)
if scale.ndim == mean.ndim: # Batched scale
cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale, axis=-1)
else: # Non-contextual scale (single scale for all batches)
cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale)
return mean_part, cov_part