Many fixes, that should have been multiple commits...
This commit is contained in:
parent
e414d8c5b2
commit
e0eb46e14c
@ -14,22 +14,29 @@ class BaseProjection(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def project(self, policy_params: Dict[str, jnp.ndarray],
|
def project(self, policy_params: Dict[str, jnp.ndarray],
|
||||||
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
|
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
|
||||||
pass
|
"""Project policy parameters.
|
||||||
|
|
||||||
@abstractmethod
|
Args:
|
||||||
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
policy_params: Dictionary with:
|
||||||
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
- 'loc': mean parameters (batch_size, dim)
|
||||||
|
- 'scale': standard deviations (batch_size, dim) if full_cov=False
|
||||||
|
- 'scale_tril': Cholesky factor (batch_size, dim, dim) if full_cov=True
|
||||||
|
old_policy_params: Same format as policy_params
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||||
|
"""Convert scale representation to covariance matrix."""
|
||||||
if not self.full_cov:
|
if not self.full_cov:
|
||||||
return jnp.diag(params["scale"] ** 2)
|
scale = params["scale"] # standard deviations
|
||||||
|
return jnp.square(scale) # diagonal covariance
|
||||||
else:
|
else:
|
||||||
scale_tril = params["scale_tril"]
|
scale_tril = params["scale_tril"] # Cholesky factor
|
||||||
return jnp.matmul(scale_tril, jnp.swapaxes(scale_tril, -1, -2))
|
return jnp.matmul(scale_tril, jnp.swapaxes(scale_tril, -1, -2))
|
||||||
|
|
||||||
def _calc_scale_or_scale_tril(self, cov: jnp.ndarray) -> jnp.ndarray:
|
def _calc_scale_from_cov(self, cov: jnp.ndarray) -> jnp.ndarray:
|
||||||
|
"""Convert covariance matrix back to appropriate scale representation."""
|
||||||
if not self.full_cov:
|
if not self.full_cov:
|
||||||
return jnp.sqrt(jnp.diagonal(cov, axis1=-2, axis2=-1))
|
return jnp.sqrt(jnp.diagonal(cov, axis1=-2, axis2=-1)) # standard deviations
|
||||||
else:
|
else:
|
||||||
return jnp.linalg.cholesky(cov)
|
return jnp.linalg.cholesky(cov) # Cholesky factor
|
@ -15,16 +15,36 @@ class FrobeniusProjection(BaseProjection):
|
|||||||
mean = policy_params["loc"]
|
mean = policy_params["loc"]
|
||||||
old_mean = old_policy_params["loc"]
|
old_mean = old_policy_params["loc"]
|
||||||
|
|
||||||
|
# Convert to covariance representation
|
||||||
cov = self._calc_covariance(policy_params)
|
cov = self._calc_covariance(policy_params)
|
||||||
old_cov = self._calc_covariance(old_policy_params)
|
old_cov = self._calc_covariance(old_policy_params)
|
||||||
|
|
||||||
mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
|
if not self.contextual_std:
|
||||||
|
# Use only first batch element for scale, keeping batch dim
|
||||||
|
cov = cov[:1] # shape: (1, dim)
|
||||||
|
old_cov = old_cov[:1] # shape: (1, dim)
|
||||||
|
|
||||||
|
# Project in covariance space
|
||||||
|
mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
|
||||||
proj_mean = self._mean_projection(mean, old_mean, mean_part)
|
proj_mean = self._mean_projection(mean, old_mean, mean_part)
|
||||||
proj_cov = self._cov_projection(cov, old_cov, cov_part)
|
proj_cov = self._cov_projection(cov, old_cov, cov_part)
|
||||||
|
|
||||||
scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov)
|
# Convert back to appropriate scale representation
|
||||||
return {"loc": proj_mean, "scale": scale_or_scale_tril}
|
scale_or_tril = self._calc_scale_from_cov(proj_cov)
|
||||||
|
|
||||||
|
if not self.contextual_std:
|
||||||
|
# Broadcast scale to match original shape
|
||||||
|
if self.full_cov:
|
||||||
|
target_shape = policy_params["scale_tril"].shape
|
||||||
|
else:
|
||||||
|
target_shape = policy_params["scale"].shape
|
||||||
|
scale_or_tril = jnp.broadcast_to(scale_or_tril, target_shape)
|
||||||
|
|
||||||
|
# Return with correct key
|
||||||
|
if self.full_cov:
|
||||||
|
return {"loc": proj_mean, "scale_tril": scale_or_tril}
|
||||||
|
else:
|
||||||
|
return {"loc": proj_mean, "scale": scale_or_tril}
|
||||||
|
|
||||||
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
||||||
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||||
@ -72,7 +92,13 @@ class FrobeniusProjection(BaseProjection):
|
|||||||
eta)
|
eta)
|
||||||
eta = jnp.maximum(-eta, eta)
|
eta = jnp.maximum(-eta, eta)
|
||||||
|
|
||||||
new_cov = (cov + jnp.einsum('i,ijk->ijk', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
|
if self.full_cov:
|
||||||
proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov)
|
new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
|
||||||
|
else:
|
||||||
|
# For diagonal case, simple broadcasting
|
||||||
|
new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
|
||||||
|
|
||||||
|
proj_cov = jnp.where(cov_mask[..., None] if not self.full_cov else cov_mask[..., None, None],
|
||||||
|
new_cov, cov)
|
||||||
|
|
||||||
return proj_cov
|
return proj_cov
|
@ -34,8 +34,17 @@ class KLProjection(BaseProjection):
|
|||||||
def project(self, policy_params: Dict[str, jnp.ndarray],
|
def project(self, policy_params: Dict[str, jnp.ndarray],
|
||||||
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
|
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
|
||||||
self._validate_inputs(policy_params, old_policy_params)
|
self._validate_inputs(policy_params, old_policy_params)
|
||||||
mean, scale_or_tril = policy_params["loc"], policy_params["scale"]
|
|
||||||
old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params["scale"]
|
# Get the right scale parameter based on full_cov
|
||||||
|
mean = policy_params["loc"]
|
||||||
|
old_mean = old_policy_params["loc"]
|
||||||
|
|
||||||
|
if self.full_cov:
|
||||||
|
scale_or_tril = policy_params["scale_tril"]
|
||||||
|
old_scale_or_tril = old_policy_params["scale_tril"]
|
||||||
|
else:
|
||||||
|
scale_or_tril = policy_params["scale"]
|
||||||
|
old_scale_or_tril = old_policy_params["scale"]
|
||||||
|
|
||||||
mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril),
|
mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril),
|
||||||
(old_mean, old_scale_or_tril))
|
(old_mean, old_scale_or_tril))
|
||||||
@ -54,7 +63,10 @@ class KLProjection(BaseProjection):
|
|||||||
(mean.shape[0],) + proj_scale_or_tril.shape[1:]
|
(mean.shape[0],) + proj_scale_or_tril.shape[1:]
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"loc": proj_mean, "scale": proj_scale_or_tril}
|
if self.full_cov:
|
||||||
|
return {"loc": proj_mean, "scale_tril": proj_scale_or_tril}
|
||||||
|
else:
|
||||||
|
return {"loc": proj_mean, "scale": proj_scale_or_tril}
|
||||||
|
|
||||||
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
||||||
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||||
@ -144,7 +156,12 @@ class KLProjection(BaseProjection):
|
|||||||
return proj_scale_or_tril
|
return proj_scale_or_tril
|
||||||
|
|
||||||
def _validate_inputs(self, policy_params, old_policy_params):
|
def _validate_inputs(self, policy_params, old_policy_params):
|
||||||
required_keys = ["loc", "scale"]
|
"""Validate input parameters have correct format."""
|
||||||
|
if self.full_cov:
|
||||||
|
required_keys = ["loc", "scale_tril"]
|
||||||
|
else:
|
||||||
|
required_keys = ["loc", "scale"]
|
||||||
|
|
||||||
for key in required_keys:
|
for key in required_keys:
|
||||||
if key not in policy_params or key not in old_policy_params:
|
if key not in policy_params or key not in old_policy_params:
|
||||||
raise KeyError(f"Missing required key '{key}' in policy parameters")
|
raise KeyError(f"Missing required key '{key}' in policy parameters")
|
||||||
|
@ -11,38 +11,6 @@ def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
|
|||||||
"""
|
"""
|
||||||
return scale_tril
|
return scale_tril
|
||||||
|
|
||||||
def gaussian_wasserstein_commutative(p: Tuple[jnp.ndarray, jnp.ndarray],
|
|
||||||
q: Tuple[jnp.ndarray, jnp.ndarray],
|
|
||||||
scale_prec: bool = False) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
|
||||||
mean, scale_or_sqrt = p
|
|
||||||
mean_other, scale_or_sqrt_other = q
|
|
||||||
|
|
||||||
mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)
|
|
||||||
|
|
||||||
if scale_or_sqrt.ndim == mean.ndim: # Diagonal case
|
|
||||||
cov = scale_or_sqrt ** 2
|
|
||||||
cov_other = scale_or_sqrt_other ** 2
|
|
||||||
if scale_prec:
|
|
||||||
identity = jnp.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype)
|
|
||||||
sqrt_inv_other = 1 / scale_or_sqrt_other
|
|
||||||
c = sqrt_inv_other ** 2 * cov
|
|
||||||
cov_part = jnp.sum(identity + c - 2 * sqrt_inv_other * scale_or_sqrt, axis=-1)
|
|
||||||
else:
|
|
||||||
cov_part = jnp.sum(cov_other + cov - 2 * scale_or_sqrt_other * scale_or_sqrt, axis=-1)
|
|
||||||
else: # Full covariance case
|
|
||||||
# Note: scale_or_sqrt is treated as the matrix square root, not Cholesky decomposition
|
|
||||||
cov = jnp.matmul(scale_or_sqrt, jnp.swapaxes(scale_or_sqrt, -1, -2))
|
|
||||||
cov_other = jnp.matmul(scale_or_sqrt_other, jnp.swapaxes(scale_or_sqrt_other, -1, -2))
|
|
||||||
if scale_prec:
|
|
||||||
identity = jnp.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype)
|
|
||||||
sqrt_inv_other = jnp.linalg.solve(scale_or_sqrt_other, identity)
|
|
||||||
c = sqrt_inv_other @ cov @ jnp.swapaxes(sqrt_inv_other, -1, -2)
|
|
||||||
cov_part = jnp.trace(identity + c - 2 * sqrt_inv_other @ scale_or_sqrt)
|
|
||||||
else:
|
|
||||||
cov_part = jnp.trace(cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt)
|
|
||||||
|
|
||||||
return mean_part, cov_part
|
|
||||||
|
|
||||||
class WassersteinProjection(BaseProjection):
|
class WassersteinProjection(BaseProjection):
|
||||||
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
|
def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
|
||||||
cov_bound: float = 0.01, scale_prec: bool = False,
|
cov_bound: float = 0.01, scale_prec: bool = False,
|
||||||
@ -54,21 +22,33 @@ class WassersteinProjection(BaseProjection):
|
|||||||
|
|
||||||
def project(self, policy_params: Dict[str, jnp.ndarray],
|
def project(self, policy_params: Dict[str, jnp.ndarray],
|
||||||
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
|
old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
|
||||||
mean = policy_params["loc"]
|
assert not self.full_cov, "Wasserstein projection only supports diagonal covariance"
|
||||||
old_mean = old_policy_params["loc"]
|
|
||||||
scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"])
|
|
||||||
old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params["scale"])
|
|
||||||
|
|
||||||
mean_part, cov_part = gaussian_wasserstein_commutative(
|
mean = policy_params["loc"] # shape: (batch_size, dim)
|
||||||
(mean, scale_or_sqrt),
|
old_mean = old_policy_params["loc"]
|
||||||
(old_mean, old_scale_or_sqrt),
|
scale = policy_params["scale"] # shape: (batch_size, dim)
|
||||||
self.scale_prec
|
old_scale = old_policy_params["scale"]
|
||||||
|
|
||||||
|
original_shape = scale.shape # Store original shape for broadcasting back
|
||||||
|
|
||||||
|
if not self.contextual_std:
|
||||||
|
# Use only first batch element for scale
|
||||||
|
scale = scale[0] # shape: (dim,)
|
||||||
|
old_scale = old_scale[0] # shape: (dim,)
|
||||||
|
|
||||||
|
mean_part, scale_part = self._gaussian_wasserstein(
|
||||||
|
(mean, scale),
|
||||||
|
(old_mean, old_scale)
|
||||||
)
|
)
|
||||||
|
|
||||||
proj_mean = self._mean_projection(mean, old_mean, mean_part)
|
proj_mean = self._mean_projection(mean, old_mean, mean_part)
|
||||||
proj_scale_or_sqrt = self._cov_projection(scale_or_sqrt, old_scale_or_sqrt, cov_part)
|
proj_scale = self._scale_projection(scale, old_scale, scale_part)
|
||||||
|
|
||||||
return {"loc": proj_mean, "scale": proj_scale_or_sqrt}
|
if not self.contextual_std:
|
||||||
|
# Broadcast single scale to all batch elements
|
||||||
|
proj_scale = jnp.broadcast_to(proj_scale[None, :], original_shape)
|
||||||
|
|
||||||
|
return {"loc": proj_mean, "scale": proj_scale}
|
||||||
|
|
||||||
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
|
||||||
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
|
||||||
@ -76,10 +56,9 @@ class WassersteinProjection(BaseProjection):
|
|||||||
proj_mean = proj_policy_params["loc"]
|
proj_mean = proj_policy_params["loc"]
|
||||||
scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"])
|
scale_or_sqrt = scale_tril_to_sqrt(policy_params["scale"])
|
||||||
proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params["scale"])
|
proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params["scale"])
|
||||||
mean_part, cov_part = gaussian_wasserstein_commutative(
|
mean_part, cov_part = self._gaussian_wasserstein(
|
||||||
(mean, scale_or_sqrt),
|
(mean, scale_or_sqrt),
|
||||||
(proj_mean, proj_scale_or_sqrt),
|
(proj_mean, proj_scale_or_sqrt)
|
||||||
self.scale_prec
|
|
||||||
)
|
)
|
||||||
w2 = mean_part + cov_part
|
w2 = mean_part + cov_part
|
||||||
return w2.mean() * self.trust_region_coeff
|
return w2.mean() * self.trust_region_coeff
|
||||||
@ -92,17 +71,29 @@ class WassersteinProjection(BaseProjection):
|
|||||||
old_mean + diff * self.mean_bound / norm[..., None],
|
old_mean + diff * self.mean_bound / norm[..., None],
|
||||||
mean)
|
mean)
|
||||||
|
|
||||||
def _cov_projection(self, scale_or_sqrt: jnp.ndarray, old_scale_or_sqrt: jnp.ndarray,
|
def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray,
|
||||||
cov_part: jnp.ndarray) -> jnp.ndarray:
|
scale_part: jnp.ndarray) -> jnp.ndarray:
|
||||||
if scale_or_sqrt.ndim == old_scale_or_sqrt.ndim == 2: # Diagonal case
|
"""Project scale parameters (standard deviations for diagonal case)"""
|
||||||
diff = scale_or_sqrt - old_scale_or_sqrt
|
diff = scale - old_scale
|
||||||
norm = jnp.sqrt(cov_part)
|
norm = jnp.sqrt(scale_part)
|
||||||
return jnp.where(norm > self.cov_bound,
|
|
||||||
old_scale_or_sqrt + diff * self.cov_bound / norm[..., None],
|
if scale.ndim == 2: # Batched scale
|
||||||
scale_or_sqrt)
|
norm = norm[..., None]
|
||||||
else: # Full covariance case
|
|
||||||
diff = scale_or_sqrt - old_scale_or_sqrt
|
return jnp.where(norm > self.cov_bound,
|
||||||
norm = jnp.linalg.norm(diff, axis=(-2, -1), keepdims=True)
|
old_scale + diff * self.cov_bound / norm,
|
||||||
return jnp.where(norm > self.cov_bound,
|
scale)
|
||||||
old_scale_or_sqrt + diff * self.cov_bound / norm,
|
|
||||||
scale_or_sqrt)
|
def _gaussian_wasserstein(self, p, q):
|
||||||
|
mean, scale = p
|
||||||
|
mean_other, scale_other = q
|
||||||
|
|
||||||
|
# Keep batch dimension by only summing over feature dimension
|
||||||
|
mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1) # -> (batch_size,)
|
||||||
|
|
||||||
|
if scale.ndim == mean.ndim: # Batched scale
|
||||||
|
cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale, axis=-1)
|
||||||
|
else: # Non-contextual scale (single scale for all batches)
|
||||||
|
cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale)
|
||||||
|
|
||||||
|
return mean_part, cov_part
|
Loading…
Reference in New Issue
Block a user