Many fixes that should have been multiple commits...
This commit is contained in:
parent
e414d8c5b2
commit
e0eb46e14c
@ -14,22 +14,29 @@ class BaseProjection(ABC):
|
||||
@abstractmethod
def project(self, policy_params: Dict[str, jnp.ndarray],
            old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
    """Project policy parameters.

    Args:
        policy_params: Dictionary with:
            - 'loc': mean parameters (batch_size, dim)
            - 'scale': standard deviations (batch_size, dim) if full_cov=False
            - 'scale_tril': Cholesky factor (batch_size, dim, dim) if full_cov=True
        old_policy_params: Same format as policy_params

    Returns:
        A parameter dictionary (see the return annotation).
        NOTE(review): presumably the same layout as ``policy_params`` --
        confirm against the concrete subclasses.
    """
    pass


@abstractmethod
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
                          proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Compute the trust-region loss between raw and projected parameters.

    Args:
        policy_params: Parameters of the unprojected policy.
            NOTE(review): presumably the same dictionary layout that
            ``project`` accepts ('loc' plus 'scale' or 'scale_tril') --
            confirm against the concrete subclasses.
        proj_policy_params: Parameters of the projected policy, in the same
            format as ``policy_params``.

    Returns:
        The loss as a ``jnp.ndarray`` (see the return annotation).
    """
    pass
|
||||
|
||||
def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Convert scale representation to covariance matrix.

    Args:
        params: Parameter dictionary. ``params["scale"]`` is read when
            ``self.full_cov`` is false, ``params["scale_tril"]`` otherwise.

    Returns:
        Diagonal case: the element-wise square of ``scale`` (the
        per-dimension variances, same shape as ``scale``).
        Full case: ``scale_tril @ scale_tril^T`` over the last two axes.
    """
    if self.full_cov:
        scale_tril = params["scale_tril"]  # Cholesky factor
        return jnp.matmul(scale_tril, jnp.swapaxes(scale_tril, -1, -2))

    # Diagonal covariance is kept as a vector of variances. Building a dense
    # matrix with jnp.diag would instead *extract* a diagonal when given a
    # 2-D (batched) array.
    scale = params["scale"]  # standard deviations
    return jnp.square(scale)
|
||||
|
||||
def _calc_scale_from_cov(self, cov: jnp.ndarray) -> jnp.ndarray:
    """Convert covariance back to the appropriate scale representation.

    This is the inverse of ``_calc_covariance``.

    Args:
        cov: Diagonal case (``self.full_cov`` false): per-dimension
            variances, i.e. the element-wise squared standard deviations.
            Full case: covariance matrices in the last two axes.

    Returns:
        Diagonal case: standard deviations, same shape as ``cov``.
        Full case: the Cholesky factor of ``cov``.
    """
    if self.full_cov:
        return jnp.linalg.cholesky(cov)  # Cholesky factor

    # The diagonal covariance is stored as a vector of variances, so the
    # inverse is an element-wise square root. Taking jnp.diagonal over the
    # last two axes would treat (batch, dim) as one matrix and return only
    # min(batch, dim) entries.
    return jnp.sqrt(cov)  # standard deviations


# Previous name of the method, kept so existing callers keep working.
_calc_scale_or_scale_tril = _calc_scale_from_cov
|
@ -15,16 +15,36 @@ class FrobeniusProjection(BaseProjection):
|
||||
mean = policy_params["loc"]
|
||||
old_mean = old_policy_params["loc"]
|
||||
|
||||
# Convert to covariance representation
|
||||
cov = self._calc_covariance(policy_params)
|
||||
old_cov = self._calc_covariance(old_policy_params)
|
||||
|
||||
mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
|
||||
if not self.contextual_std:
|
||||
# Use only first batch element for scale, keeping batch dim
|
||||
cov = cov[:1] # shape: (1, dim)
|
||||
old_cov = old_cov[:1] # shape: (1, dim)
|
||||
|
||||
# Project in covariance space
|
||||
mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
|
||||
proj_mean = self._mean_projection(mean, old_mean, mean_part)
|
||||
proj_cov = self._cov_projection(cov, old_cov, cov_part)
|
||||
|
||||
scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov)
|
||||
return {"loc": proj_mean, "scale": scale_or_scale_tril}
|
||||
# Convert back to appropriate scale representation
|
||||
scale_or_tril = self._calc_scale_from_cov(proj_cov)
|
||||
|
||||
if not self.contextual_std:
|
||||
# Broadcast scale to match original shape
|
||||
if self.full_cov:
|
||||
target_shape = policy_params["scale_tril"].shape
|
||||
else:
|
||||
target_shape = policy_params["scale"].shape
|
||||
scale_or_tril = jnp.broadcast_to(scale_or_tril, target_shape)
|
||||
|
||||
# Return with correct key
|
||||
if self.full_cov:
|
||||
return {"loc": proj_mean, "scale_tril": scale_or_tril}
|
||||
else:
|
||||
return {"loc": proj_mean, "scale": scale_or_tril}
|
||||
|
||||
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
||||
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||
@ -72,7 +92,13 @@ class FrobeniusProjection(BaseProjection):
|
||||
eta)
|
||||
eta = jnp.maximum(-eta, eta)
|
||||
|
||||
new_cov = (cov + jnp.einsum('i,ijk->ijk', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
|
||||
proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov)
|
||||
if self.full_cov:
|
||||
new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
|
||||
else:
|
||||
# For diagonal case, simple broadcasting
|
||||
new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
|
||||
|
||||
proj_cov = jnp.where(cov_mask[..., None] if not self.full_cov else cov_mask[..., None, None],
|
||||
new_cov, cov)
|
||||
|
||||
return proj_cov
|
@ -34,8 +34,17 @@ class KLProjection(BaseProjection):
|
||||
def project(self, policy_params: Dict[str, jnp.ndarray],
|
||||
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
|
||||
self._validate_inputs(policy_params, old_policy_params)
|
||||
mean, scale_or_tril = policy_params["loc"], policy_params["scale"]
|
||||
old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params["scale"]
|
||||
|
||||
# Get the right scale parameter based on full_cov
|
||||
mean = policy_params["loc"]
|
||||
old_mean = old_policy_params["loc"]
|
||||
|
||||
if self.full_cov:
|
||||
scale_or_tril = policy_params["scale_tril"]
|
||||
old_scale_or_tril = old_policy_params["scale_tril"]
|
||||
else:
|
||||
scale_or_tril = policy_params["scale"]
|
||||
old_scale_or_tril = old_policy_params["scale"]
|
||||
|
||||
mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril),
|
||||
(old_mean, old_scale_or_tril))
|
||||
@ -54,7 +63,10 @@ class KLProjection(BaseProjection):
|
||||
(mean.shape[0],) + proj_scale_or_tril.shape[1:]
|
||||
)
|
||||
|
||||
return {"loc": proj_mean, "scale": proj_scale_or_tril}
|
||||
if self.full_cov:
|
||||
return {"loc": proj_mean, "scale_tril": proj_scale_or_tril}
|
||||
else:
|
||||
return {"loc": proj_mean, "scale": proj_scale_or_tril}
|
||||
|
||||
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
||||
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||
@ -144,7 +156,12 @@ class KLProjection(BaseProjection):
|
||||
return proj_scale_or_tril
|
||||
|
||||
def _validate_inputs(self, policy_params, old_policy_params):
|
||||
required_keys = ["loc", "scale"]
|
||||
"""Validate input parameters have correct format."""
|
||||
if self.full_cov:
|
||||
required_keys = ["loc", "scale_tril"]
|
||||
else:
|
||||
required_keys = ["loc", "scale"]
|
||||
|
||||
for key in required_keys:
|
||||
if key not in policy_params or key not in old_policy_params:
|
||||
raise KeyError(f"Missing required key '{key}' in policy parameters")
|
||||
|
@ -11,38 +11,6 @@ def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
|
||||
"""
|
||||
return scale_tril
|
||||
|
||||
def gaussian_wasserstein_commutative(p: Tuple[jnp.ndarray, jnp.ndarray],
                                     q: Tuple[jnp.ndarray, jnp.ndarray],
                                     scale_prec: bool = False) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return the mean and covariance parts of a Wasserstein-style distance.

    Args:
        p: ``(mean, scale_or_sqrt)`` of the first Gaussian.
        q: ``(mean, scale_or_sqrt)`` of the second Gaussian.
        scale_prec: If true, the covariance term is measured after
            whitening with the inverse of ``q``'s scale / square root.

    Returns:
        ``(mean_part, cov_part)``. ``mean_part`` is the squared Euclidean
        distance between the means, reduced over the last axis.
        ``cov_part`` is reduced over the feature axis (diagonal case) or the
        last two axes (full case), so leading batch axes are preserved.

    The diagonal case is selected when ``scale_or_sqrt`` has the same rank
    as ``mean``; otherwise ``scale_or_sqrt`` is treated as a matrix.
    """
    mean, scale_or_sqrt = p
    mean_other, scale_or_sqrt_other = q

    mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)

    if scale_or_sqrt.ndim == mean.ndim:  # Diagonal case
        if scale_prec:
            # sum(1 + (s/o)^2 - 2 s/o) == sum((s/o - 1)^2). A (dim, dim)
            # identity matrix must not be added to per-dimension vectors.
            ratio = scale_or_sqrt / scale_or_sqrt_other
            cov_part = jnp.sum(jnp.square(ratio - 1.0), axis=-1)
        else:
            # sum(o^2 + s^2 - 2 o s) == sum((s - o)^2)
            cov_part = jnp.sum(jnp.square(scale_or_sqrt - scale_or_sqrt_other), axis=-1)
    else:  # Full covariance case
        # Note: scale_or_sqrt is treated as the matrix square root, not Cholesky decomposition
        cov = jnp.matmul(scale_or_sqrt, jnp.swapaxes(scale_or_sqrt, -1, -2))
        if scale_prec:
            identity = jnp.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype)
            # Give the right-hand side the same batch shape as the matrix so
            # it is never interpreted as a batch of vectors.
            rhs = jnp.broadcast_to(identity, scale_or_sqrt_other.shape)
            sqrt_inv_other = jnp.linalg.solve(scale_or_sqrt_other, rhs)
            c = sqrt_inv_other @ cov @ jnp.swapaxes(sqrt_inv_other, -1, -2)
            inner = identity + c - 2 * sqrt_inv_other @ scale_or_sqrt
        else:
            cov_other = jnp.matmul(scale_or_sqrt_other,
                                   jnp.swapaxes(scale_or_sqrt_other, -1, -2))
            inner = cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt
        # Trace over the matrix axes; the default (axis1=0, axis2=1) would
        # trace over the batch axis for batched input.
        cov_part = jnp.trace(inner, axis1=-2, axis2=-1)

    return mean_part, cov_part
|
||||
|
||||
class WassersteinProjection(BaseProjection):
|
||||
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
|
||||
cov_bound: float = 0.01, scale_prec: bool = False,
|
||||
@ -54,21 +22,33 @@ class WassersteinProjection(BaseProjection):
|
||||
|
||||
def project(self, policy_params: Dict[str, jnp.ndarray],
            old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
    """Project mean and diagonal scale towards the old policy.

    Args:
        policy_params: Dictionary with 'loc' (batch_size, dim) and
            'scale' (batch_size, dim).
        old_policy_params: Same format as ``policy_params``.

    Returns:
        ``{"loc": projected_mean, "scale": projected_scale}`` where the
        scale has the same shape as ``policy_params["scale"]``.

    Raises:
        AssertionError: If ``self.full_cov`` is set.
    """
    # NOTE(review): asserts are stripped under ``python -O``; kept so the
    # exception type seen by callers does not change.
    assert not self.full_cov, "Wasserstein projection only supports diagonal covariance"

    mean = policy_params["loc"]  # shape: (batch_size, dim)
    old_mean = old_policy_params["loc"]
    scale = policy_params["scale"]  # shape: (batch_size, dim)
    old_scale = old_policy_params["scale"]

    original_shape = scale.shape  # Store original shape for broadcasting back

    if not self.contextual_std:
        # Use only first batch element for scale
        scale = scale[0]  # shape: (dim,)
        old_scale = old_scale[0]  # shape: (dim,)

    mean_part, scale_part = self._gaussian_wasserstein(
        (mean, scale),
        (old_mean, old_scale)
    )

    proj_mean = self._mean_projection(mean, old_mean, mean_part)
    proj_scale = self._scale_projection(scale, old_scale, scale_part)

    if not self.contextual_std:
        # Broadcast single scale to all batch elements
        proj_scale = jnp.broadcast_to(proj_scale[None, :], original_shape)

    return {"loc": proj_mean, "scale": proj_scale}
|
||||
|
||||
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
||||
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||
@ -76,10 +56,9 @@ class WassersteinProjection(BaseProjection):
|
||||
proj_mean = proj_policy_params["loc"]
|
||||
scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"])
|
||||
proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params["scale"])
|
||||
mean_part, cov_part = gaussian_wasserstein_commutative(
|
||||
mean_part, cov_part = self._gaussian_wasserstein(
|
||||
(mean, scale_or_sqrt),
|
||||
(proj_mean, proj_scale_or_sqrt),
|
||||
self.scale_prec
|
||||
(proj_mean, proj_scale_or_sqrt)
|
||||
)
|
||||
w2 = mean_part + cov_part
|
||||
return w2.mean() * self.trust_region_coeff
|
||||
@ -92,17 +71,29 @@ class WassersteinProjection(BaseProjection):
|
||||
old_mean + diff * self.mean_bound / norm[..., None],
|
||||
mean)
|
||||
|
||||
def _cov_projection(self, scale_or_sqrt: jnp.ndarray, old_scale_or_sqrt: jnp.ndarray,
                    cov_part: jnp.ndarray) -> jnp.ndarray:
    """Pull the scale back towards the old one when it moved too far.

    Args:
        scale_or_sqrt: Current scale; rank 2 selects the diagonal case,
            anything else is treated as matrices in the last two axes.
        old_scale_or_sqrt: Old scale, same layout.
        cov_part: Squared distance used for the diagonal case, one value
            per leading index. The full case ignores it and uses the
            Frobenius norm of the difference instead.

    Returns:
        ``scale_or_sqrt`` where the distance is within ``self.cov_bound``;
        otherwise the point at distance ``self.cov_bound`` on the segment
        from the old scale to the new one.
    """
    diff = scale_or_sqrt - old_scale_or_sqrt

    if scale_or_sqrt.ndim == old_scale_or_sqrt.ndim == 2:  # Diagonal case
        # Trailing axis so the mask broadcasts against (batch, dim) operands.
        sq_norm = jnp.asarray(cov_part)[..., None]
    else:  # Full covariance case
        sq_norm = jnp.sum(jnp.square(diff), axis=(-2, -1), keepdims=True)

    # sqrt has an unbounded derivative at 0 and jnp.where still
    # differentiates the untaken branch, so keep every intermediate finite.
    positive = sq_norm > 0
    norm = jnp.where(positive, jnp.sqrt(jnp.where(positive, sq_norm, 1.0)), 0.0)

    exceeds = norm > self.cov_bound
    safe_norm = jnp.where(exceeds, norm, 1.0)
    return jnp.where(exceeds,
                     old_scale_or_sqrt + diff * self.cov_bound / safe_norm,
                     scale_or_sqrt)
|
||||
def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray,
                      scale_part: jnp.ndarray) -> jnp.ndarray:
    """Project scale parameters (standard deviations for diagonal case).

    Args:
        scale: Current standard deviations; rank 2 means one row per batch
            element, otherwise a single shared vector.
        old_scale: Old standard deviations, same layout as ``scale``.
        scale_part: Squared distance between ``scale`` and ``old_scale``
            (one value per batch row, or a scalar for a shared vector).

    Returns:
        ``scale`` where ``sqrt(scale_part)`` is within ``self.cov_bound``;
        otherwise the point at distance ``self.cov_bound`` on the segment
        from ``old_scale`` to ``scale``.
    """
    diff = scale - old_scale

    sq_norm = jnp.asarray(scale_part)
    if scale.ndim == 2:  # Batched scale
        sq_norm = sq_norm[..., None]

    # sqrt has an unbounded derivative at 0 and jnp.where still
    # differentiates the untaken branch, so keep every intermediate finite.
    # Otherwise gradients become NaN whenever scale == old_scale.
    positive = sq_norm > 0
    norm = jnp.where(positive, jnp.sqrt(jnp.where(positive, sq_norm, 1.0)), 0.0)

    exceeds = norm > self.cov_bound
    safe_norm = jnp.where(exceeds, norm, 1.0)
    return jnp.where(exceeds,
                     old_scale + diff * self.cov_bound / safe_norm,
                     scale)
|
||||
|
||||
def _gaussian_wasserstein(self, p, q):
    """Return the mean and scale parts of the distance between two Gaussians.

    Args:
        p: ``(mean, scale)`` of the first distribution.
        q: ``(mean, scale)`` of the second distribution.

    Returns:
        ``(mean_part, cov_part)``. ``mean_part`` is the squared distance
        between the means summed over the last axis. ``cov_part`` is the
        squared distance between the scales: summed over the last axis when
        ``scale`` has the same rank as ``mean``, otherwise summed over all
        elements (a single scale shared by the whole batch).
    """
    mean, scale = p
    mean_other, scale_other = q

    # Keep batch dimension by only summing over feature dimension
    mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)  # -> (batch_size,)

    # (a - b)^2 instead of a^2 + b^2 - 2ab: the expanded form cancels
    # catastrophically for nearly equal scales and can come out as 0 or
    # slightly negative, which breaks a later sqrt.
    sq_diff = jnp.square(scale - scale_other)

    if scale.ndim == mean.ndim:  # Batched scale
        cov_part = jnp.sum(sq_diff, axis=-1)
    else:  # Non-contextual scale (single scale for all batches)
        cov_part = jnp.sum(sq_diff)

    return mean_part, cov_part
|
Loading…
Reference in New Issue
Block a user