Many fixes that should have been multiple commits...

This commit is contained in:
Dominik Moritz Roth 2024-12-21 17:48:53 +01:00
parent e414d8c5b2
commit e0eb46e14c
4 changed files with 118 additions and 77 deletions

View File

@ -14,22 +14,29 @@ class BaseProjection(ABC):
@abstractmethod
def project(self, policy_params: Dict[str, jnp.ndarray],
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
pass
@abstractmethod
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""Project policy parameters.
Args:
policy_params: Dictionary with:
- 'loc': mean parameters (batch_size, dim)
- 'scale': standard deviations (batch_size, dim) if full_cov=False
- 'scale_tril': Cholesky factor (batch_size, dim, dim) if full_cov=True
old_policy_params: Same format as policy_params
"""
pass
def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Convert the scale representation in ``params`` to a covariance.

    Args:
        params: Policy parameters. ``params["scale"]`` is read when
            ``self.full_cov`` is false, ``params["scale_tril"]`` otherwise.

    Returns:
        When ``self.full_cov`` is false: the element-wise square of
        ``params["scale"]`` (same shape as the input).
        When ``self.full_cov`` is true: ``scale_tril @ scale_tril^T`` taken
        over the last two axes.
    """
    if not self.full_cov:
        # Element-wise square keeps every leading (batch) axis intact.
        # jnp.diag must not be used here: on a 2-D input it extracts that
        # array's diagonal instead of building one matrix per row.
        return jnp.square(params["scale"])
    scale_tril = params["scale_tril"]
    return jnp.matmul(scale_tril, jnp.swapaxes(scale_tril, -1, -2))


def _calc_scale_from_cov(self, cov: jnp.ndarray) -> jnp.ndarray:
    """Convert a covariance back to the matching scale representation.

    Exact inverse of ``_calc_covariance`` for non-negative scales.

    Args:
        cov: Output format of ``_calc_covariance``: element-wise variances
            when ``self.full_cov`` is false, covariance matrices otherwise.

    Returns:
        Element-wise square root of ``cov`` when ``self.full_cov`` is false,
        otherwise the Cholesky factor of ``cov``.
    """
    if not self.full_cov:
        # The diagonal representation is already a vector of variances, so
        # no jnp.diagonal extraction is needed (it would mix up the batch
        # and feature axes of a 2-D input).
        return jnp.sqrt(cov)
    return jnp.linalg.cholesky(cov)


# Previous name of the helper, kept so existing call sites keep working.
_calc_scale_or_scale_tril = _calc_scale_from_cov

View File

@ -15,16 +15,36 @@ class FrobeniusProjection(BaseProjection):
mean = policy_params["loc"]
old_mean = old_policy_params["loc"]
# Convert to covariance representation
cov = self._calc_covariance(policy_params)
old_cov = self._calc_covariance(old_policy_params)
mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
if not self.contextual_std:
# Use only first batch element for scale, keeping batch dim
cov = cov[:1] # shape: (1, dim)
old_cov = old_cov[:1] # shape: (1, dim)
# Project in covariance space
mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_cov = self._cov_projection(cov, old_cov, cov_part)
scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov)
return {"loc": proj_mean, "scale": scale_or_scale_tril}
# Convert back to appropriate scale representation
scale_or_tril = self._calc_scale_from_cov(proj_cov)
if not self.contextual_std:
# Broadcast scale to match original shape
if self.full_cov:
target_shape = policy_params["scale_tril"].shape
else:
target_shape = policy_params["scale"].shape
scale_or_tril = jnp.broadcast_to(scale_or_tril, target_shape)
# Return with correct key
if self.full_cov:
return {"loc": proj_mean, "scale_tril": scale_or_tril}
else:
return {"loc": proj_mean, "scale": scale_or_tril}
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
@ -72,7 +92,13 @@ class FrobeniusProjection(BaseProjection):
eta)
eta = jnp.maximum(-eta, eta)
new_cov = (cov + jnp.einsum('i,ijk->ijk', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov)
if self.full_cov:
new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
else:
# For diagonal case, simple broadcasting
new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
proj_cov = jnp.where(cov_mask[..., None] if not self.full_cov else cov_mask[..., None, None],
new_cov, cov)
return proj_cov

View File

@ -34,8 +34,17 @@ class KLProjection(BaseProjection):
def project(self, policy_params: Dict[str, jnp.ndarray],
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
self._validate_inputs(policy_params, old_policy_params)
mean, scale_or_tril = policy_params["loc"], policy_params["scale"]
old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params["scale"]
# Get the right scale parameter based on full_cov
mean = policy_params["loc"]
old_mean = old_policy_params["loc"]
if self.full_cov:
scale_or_tril = policy_params["scale_tril"]
old_scale_or_tril = old_policy_params["scale_tril"]
else:
scale_or_tril = policy_params["scale"]
old_scale_or_tril = old_policy_params["scale"]
mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril),
(old_mean, old_scale_or_tril))
@ -54,7 +63,10 @@ class KLProjection(BaseProjection):
(mean.shape[0],) + proj_scale_or_tril.shape[1:]
)
return {"loc": proj_mean, "scale": proj_scale_or_tril}
if self.full_cov:
return {"loc": proj_mean, "scale_tril": proj_scale_or_tril}
else:
return {"loc": proj_mean, "scale": proj_scale_or_tril}
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
@ -144,7 +156,12 @@ class KLProjection(BaseProjection):
return proj_scale_or_tril
def _validate_inputs(self, policy_params, old_policy_params):
required_keys = ["loc", "scale"]
"""Validate input parameters have correct format."""
if self.full_cov:
required_keys = ["loc", "scale_tril"]
else:
required_keys = ["loc", "scale"]
for key in required_keys:
if key not in policy_params or key not in old_policy_params:
raise KeyError(f"Missing required key '{key}' in policy parameters")

View File

@ -11,38 +11,6 @@ def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
"""
return scale_tril
def gaussian_wasserstein_commutative(p: Tuple[jnp.ndarray, jnp.ndarray],
                                     q: Tuple[jnp.ndarray, jnp.ndarray],
                                     scale_prec: bool = False) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return the mean and covariance terms of a Gaussian W2-style distance.

    The covariance term uses the commutative form
    ``tr(A + B - 2 * sqrt(A) @ sqrt(B))``.

    Args:
        p: ``(mean, scale_or_sqrt)`` of the first Gaussian.
        q: ``(mean, scale_or_sqrt)`` of the second Gaussian.
        scale_prec: If true, the covariance term is expressed relative to
            ``q``'s scale (``tr(I + C - 2 * sqrt_q^-1 @ sqrt_p)``).

    Returns:
        ``(mean_part, cov_part)``. ``mean_part`` is the squared Euclidean
        distance of the means over the last axis. ``cov_part`` is reduced
        over the last axis (diagonal case) or the last two axes (full case),
        so leading batch axes are preserved.
    """
    mean, scale_or_sqrt = p
    mean_other, scale_or_sqrt_other = q

    mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)

    if scale_or_sqrt.ndim == mean.ndim:  # Diagonal case
        if scale_prec:
            # Per-dimension form of tr(I + C - 2 * sqrt_q^-1 * sqrt_p).
            # The identity contributes a scalar 1 per dimension; adding a
            # (dim, dim) eye here would broadcast against the batch axis.
            sqrt_inv_other = 1 / scale_or_sqrt_other
            c = sqrt_inv_other ** 2 * scale_or_sqrt ** 2
            cov_part = jnp.sum(1.0 + c - 2 * sqrt_inv_other * scale_or_sqrt, axis=-1)
        else:
            # (a - b)^2 equals a^2 + b^2 - 2ab but cannot round below zero.
            cov_part = jnp.sum(jnp.square(scale_or_sqrt - scale_or_sqrt_other), axis=-1)
    else:  # Full covariance case
        # Note: scale_or_sqrt is treated as the matrix square root, not the
        # Cholesky decomposition.
        cov = jnp.matmul(scale_or_sqrt, jnp.swapaxes(scale_or_sqrt, -1, -2))
        if scale_prec:
            identity = jnp.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype)
            sqrt_inv_other = jnp.linalg.solve(scale_or_sqrt_other, identity)
            c = sqrt_inv_other @ cov @ jnp.swapaxes(sqrt_inv_other, -1, -2)
            inner = identity + c - 2 * sqrt_inv_other @ scale_or_sqrt
        else:
            cov_other = jnp.matmul(scale_or_sqrt_other,
                                   jnp.swapaxes(scale_or_sqrt_other, -1, -2))
            inner = cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt
        # Trace over the matrix axes; the default (axis1=0, axis2=1) would
        # trace over the batch axis for batched input.
        cov_part = jnp.trace(inner, axis1=-2, axis2=-1)

    return mean_part, cov_part
class WassersteinProjection(BaseProjection):
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
cov_bound: float = 0.01, scale_prec: bool = False,
@ -54,21 +22,33 @@ class WassersteinProjection(BaseProjection):
def project(self, policy_params: Dict[str, jnp.ndarray],
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
mean = policy_params["loc"]
assert not self.full_cov, "Wasserstein projection only supports diagonal covariance"
mean = policy_params["loc"] # shape: (batch_size, dim)
old_mean = old_policy_params["loc"]
scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"])
old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params["scale"])
scale = policy_params["scale"] # shape: (batch_size, dim)
old_scale = old_policy_params["scale"]
mean_part, cov_part = gaussian_wasserstein_commutative(
(mean, scale_or_sqrt),
(old_mean, old_scale_or_sqrt),
self.scale_prec
original_shape = scale.shape # Store original shape for broadcasting back
if not self.contextual_std:
# Use only first batch element for scale
scale = scale[0] # shape: (dim,)
old_scale = old_scale[0] # shape: (dim,)
mean_part, scale_part = self._gaussian_wasserstein(
(mean, scale),
(old_mean, old_scale)
)
proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_scale_or_sqrt = self._cov_projection(scale_or_sqrt, old_scale_or_sqrt, cov_part)
proj_scale = self._scale_projection(scale, old_scale, scale_part)
return {"loc": proj_mean, "scale": proj_scale_or_sqrt}
if not self.contextual_std:
# Broadcast single scale to all batch elements
proj_scale = jnp.broadcast_to(proj_scale[None, :], original_shape)
return {"loc": proj_mean, "scale": proj_scale}
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
@ -76,10 +56,9 @@ class WassersteinProjection(BaseProjection):
proj_mean = proj_policy_params["loc"]
scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"])
proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params["scale"])
mean_part, cov_part = gaussian_wasserstein_commutative(
mean_part, cov_part = self._gaussian_wasserstein(
(mean, scale_or_sqrt),
(proj_mean, proj_scale_or_sqrt),
self.scale_prec
(proj_mean, proj_scale_or_sqrt)
)
w2 = mean_part + cov_part
return w2.mean() * self.trust_region_coeff
@ -92,17 +71,29 @@ class WassersteinProjection(BaseProjection):
old_mean + diff * self.mean_bound / norm[..., None],
mean)
def _cov_projection(self, scale_or_sqrt: jnp.ndarray, old_scale_or_sqrt: jnp.ndarray,
                    cov_part: jnp.ndarray) -> jnp.ndarray:
    """Pull the scale back towards the old scale when it moved too far.

    Args:
        scale_or_sqrt: New scale (2-D: diagonal case) or matrix square root
            (any other rank: full case).
        old_scale_or_sqrt: Old scale in the same format.
        cov_part: Squared distance of the scales; only used in the diagonal
            case. The full case uses the Frobenius norm of the difference.

    Returns:
        ``scale_or_sqrt`` where the distance is within ``self.cov_bound``,
        otherwise ``old + (new - old) * self.cov_bound / distance``.
    """
    diff = scale_or_sqrt - old_scale_or_sqrt
    if scale_or_sqrt.ndim == old_scale_or_sqrt.ndim == 2:  # Diagonal case
        # Clamp first: a squared distance that rounds slightly below zero
        # would make sqrt return NaN. The trailing axis lets the per-sample
        # norm broadcast against (batch, dim) instead of aligning with dim.
        norm = jnp.sqrt(jnp.maximum(cov_part, 0.0))[..., None]
    else:  # Full covariance case
        norm = jnp.linalg.norm(diff, axis=(-2, -1), keepdims=True)

    exceeds = norm > self.cov_bound
    # jnp.where evaluates both branches, so divide by a harmless 1 wherever
    # the rescaled branch is not selected (norm may be 0 there).
    safe_norm = jnp.where(exceeds, norm, 1.0)
    rescaled = old_scale_or_sqrt + diff * self.cov_bound / safe_norm
    return jnp.where(exceeds, rescaled, scale_or_sqrt)
def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray,
                      scale_part: jnp.ndarray) -> jnp.ndarray:
    """Project scale parameters (standard deviations for diagonal case).

    Args:
        scale: New scale; 2-D when batched, otherwise a single scale vector.
        old_scale: Old scale in the same format.
        scale_part: Squared distance between the two scales (one value per
            row when ``scale`` is 2-D).

    Returns:
        ``scale`` where ``sqrt(scale_part)`` is within ``self.cov_bound``,
        otherwise ``old_scale + (scale - old_scale) * self.cov_bound / norm``.
    """
    diff = scale - old_scale
    # Clamp before the square root: a squared distance that rounds slightly
    # below zero would otherwise turn into NaN.
    norm = jnp.sqrt(jnp.maximum(scale_part, 0.0))
    if scale.ndim == 2:  # Batched scale
        norm = norm[..., None]

    exceeds = norm > self.cov_bound
    # jnp.where evaluates both branches, so divide by a harmless 1 wherever
    # the rescaled branch is not selected (norm may be 0 there).
    safe_norm = jnp.where(exceeds, norm, 1.0)
    rescaled = old_scale + diff * self.cov_bound / safe_norm
    return jnp.where(exceeds, rescaled, scale)
def _gaussian_wasserstein(self, p, q):
    """Return the mean and scale terms of the squared distance between p and q.

    Args:
        p: ``(mean, scale)`` of the first Gaussian.
        q: ``(mean, scale)`` of the second Gaussian.

    Returns:
        ``(mean_part, cov_part)``. ``mean_part`` is the squared difference of
        the means summed over the last axis. ``cov_part`` is the squared
        difference of the scales, summed over the last axis when ``scale``
        has the same rank as ``mean`` and over all axes otherwise.
    """
    mean, scale = p
    mean_other, scale_other = q

    # Keep batch dimension by only summing over feature dimension
    mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)

    # Batched scale reduces per row; a non-contextual scale (single scale for
    # all batches) reduces to one scalar.
    axis = -1 if scale.ndim == mean.ndim else None
    # (a - b)^2 equals a^2 + b^2 - 2ab but cannot round below zero, so a
    # later square root of this value stays well defined.
    cov_part = jnp.sum(jnp.square(scale - scale_other), axis=axis)

    return mean_part, cov_part