From ecf0b72e8854bdaef6ca0bf6e6a794cee192d2ea Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Wed, 22 Jan 2025 14:03:23 +0100
Subject: [PATCH] Port some fixes / additions learned from implementing itpal_jax

---
 fancy_rl/projections/base_projection.py      |  25 ++-
 fancy_rl/projections/frobenius_projection.py |  59 ++++---
 fancy_rl/projections/kl_projection.py        | 108 ++++++----
 .../projections/wasserstein_projection.py    | 151 +++++++++-----
 4 files changed, 240 insertions(+), 103 deletions(-)

diff --git a/fancy_rl/projections/base_projection.py b/fancy_rl/projections/base_projection.py
index 3ceb5d1..2a81ce3 100644
--- a/fancy_rl/projections/base_projection.py
+++ b/fancy_rl/projections/base_projection.py
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 import torch
 from torch import nn
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 class BaseProjection(nn.Module, ABC):
     def __init__(self, in_keys: List[str], out_keys: List[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
@@ -68,4 +68,25 @@ class BaseProjection(nn.Module, ABC):
         if not self.full_cov:
             return torch.sqrt(cov.diagonal(dim1=-2, dim2=-1))
         else:
-            return torch.linalg.cholesky(cov)
\ No newline at end of file
+            return torch.linalg.cholesky(cov)
+
+    def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
+        """Project mean based on the Mahalanobis objective and trust region.
+
+        Args:
+            mean: Current mean vectors
+            old_mean: Old mean vectors
+            mean_part: Mahalanobis/Euclidean distance between the two mean vectors
+
+        Returns:
+            Projected mean that satisfies the trust region
+        """
+        mask = mean_part > self.mean_bound
+        omega = torch.ones_like(mean_part)
+        omega = torch.where(mask, torch.sqrt(mean_part / self.mean_bound) - 1., omega)
+        omega = torch.maximum(-omega, omega)[..., None]
+
+        # Use matrix operations instead of boolean indexing
+        m = (mean + omega * old_mean) / (1. + omega + 1e-16)
+        mask_matrix = mask[..., None].to(mean.dtype)
+        return mask_matrix * m + (1 - mask_matrix) * mean
\ No newline at end of file
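
Reviewer note on the new `_mean_projection`: below is a minimal standalone sketch of the interpolation it performs (the helper name, demo tensors, and bound value are illustrative, not part of the patch). For members that violate the bound, the blend (mean + omega * old_mean) / (1 + omega) with omega = sqrt(mean_part / bound) - 1 places the projected mean exactly on the trust-region boundary:

    import torch

    def mean_projection_sketch(mean, old_mean, mean_part, mean_bound=0.01):
        # omega > 0 only where the trust region is violated; elsewhere keep the mean
        mask = mean_part > mean_bound
        omega = torch.ones_like(mean_part)
        omega = torch.where(mask, torch.sqrt(mean_part / mean_bound) - 1., omega)
        omega = torch.maximum(-omega, omega)[..., None]
        m = (mean + omega * old_mean) / (1. + omega + 1e-16)
        mask_f = mask[..., None].to(mean.dtype)
        return mask_f * m + (1 - mask_f) * mean

    old_mean = torch.zeros(4, 3)
    mean = torch.ones(4, 3)  # squared distance 3.0, far outside the bound
    mean_part = torch.sum((mean - old_mean) ** 2, dim=-1)
    proj = mean_projection_sketch(mean, old_mean, mean_part)
    print(torch.sum((proj - old_mean) ** 2, dim=-1))  # ~0.01 everywhere, i.e. on the bound
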
diff --git a/fancy_rl/projections/frobenius_projection.py b/fancy_rl/projections/frobenius_projection.py
index 2a92c04..e2abfd4 100644
--- a/fancy_rl/projections/frobenius_projection.py
+++ b/fancy_rl/projections/frobenius_projection.py
@@ -1,7 +1,7 @@
 import torch
 from .base_projection import BaseProjection
 from tensordict.nn import TensorDictModule
-from typing import Dict
+from typing import Dict, Tuple
 
 class FrobeniusProjection(BaseProjection):
     def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
@@ -12,16 +12,23 @@ class FrobeniusProjection(BaseProjection):
         mean = policy_params["loc"]
         old_mean = old_policy_params["loc"]
 
+        # Convert to covariance representation
         cov = self._calc_covariance(policy_params)
         old_cov = self._calc_covariance(old_policy_params)
 
-        mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
+        if not self.contextual_std:
+            cov = cov[:1]
+            old_cov = old_cov[:1]
 
+        mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
         proj_mean = self._mean_projection(mean, old_mean, mean_part)
         proj_cov = self._cov_projection(cov, old_cov, cov_part)
 
-        scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov)
-        return {"loc": proj_mean, self.out_keys[1]: scale_or_scale_tril}
+        scale_or_tril = self._calc_scale_or_scale_tril(proj_cov)
+        if not self.contextual_std:
+            scale_or_tril = scale_or_tril.expand(mean.shape[0], *scale_or_tril.shape[1:])
+
+        return {"loc": proj_mean, self.out_keys[1]: scale_or_tril}
 
     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
         mean = policy_params["loc"]
@@ -35,34 +42,48 @@
         return (mean_diff + cov_diff).mean() * self.trust_region_coeff
 
-    def _gaussian_frobenius(self, p, q):
+    def _gaussian_frobenius(self, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         mean, cov = p
         old_mean, old_cov = q
 
         if self.scale_prec:
-            prec_old = torch.inverse(old_cov)
-            mean_part = torch.sum(torch.matmul(mean - old_mean, prec_old) * (mean - old_mean), dim=-1)
-            cov_part = torch.sum(prec_old * cov, dim=(-2, -1)) - torch.logdet(torch.matmul(prec_old, cov)) - mean.shape[-1]
+            if self.full_cov:
+                # Factorize first, then use a triangular solve instead of an explicit inverse for stability
+                diff = mean - old_mean
+                old_chol = torch.linalg.cholesky(old_cov)
+                solved = torch.triangular_solve(diff.unsqueeze(-1), old_chol, upper=False)[0].squeeze(-1)
+                mean_part = torch.sum(torch.square(solved), dim=-1)
+            else:
+                # Diagonal case: old_cov holds variances, so divide the squared difference directly
+                mean_part = torch.sum(torch.square(mean - old_mean) / old_cov, dim=-1)
         else:
             mean_part = torch.sum(torch.square(mean - old_mean), dim=-1)
-            cov_part = torch.sum(torch.square(cov - old_cov), dim=(-2, -1))
+
+        # Covariance part: squared Frobenius norm of the difference
+        diff = old_cov - cov
+        if self.full_cov:
+            cov_part = torch.sum(torch.square(diff), dim=(-2, -1))
+        else:
+            cov_part = torch.sum(torch.square(diff), dim=-1)
 
         return mean_part, cov_part
 
-    def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
-        diff = mean - old_mean
-        norm = torch.sqrt(mean_part)
-        return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm.unsqueeze(-1), mean)
-
     def _cov_projection(self, cov: torch.Tensor, old_cov: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
-        batch_shape = cov.shape[:-2]
+        batch_shape = cov.shape[:-2] if self.full_cov else cov.shape[:-1]
         cov_mask = cov_part > self.cov_bound
 
         eta = torch.ones(batch_shape, dtype=cov.dtype, device=cov.device)
-        eta[cov_mask] = torch.sqrt(cov_part[cov_mask] / self.cov_bound) - 1.
-        eta = torch.max(-eta, eta)
+        eta = torch.where(cov_mask, torch.sqrt(cov_part / self.cov_bound) - 1., eta)
+        eta = torch.maximum(-eta, eta)
 
-        new_cov = (cov + torch.einsum('i,ijk->ijk', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
-        proj_cov = torch.where(cov_mask[..., None, None], new_cov, cov)
+        if self.full_cov:
+            new_cov = (cov + torch.einsum('...,...ij->...ij', eta, old_cov)) / \
+                      (1. + eta + 1e-16)[..., None, None]
+            # torch.where needs a boolean condition; do not cast the mask to a float dtype
+            proj_cov = torch.where(cov_mask[..., None, None], new_cov, cov)
+        else:
+            new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
+            proj_cov = torch.where(cov_mask[..., None], new_cov, cov)
 
         return proj_cov
\ No newline at end of file
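
The covariance projection uses the same convex-blend trick with eta instead of omega. A quick, hypothetical check for the diagonal branch (made-up variances and bound) that the blend pulls a violating covariance back onto the Frobenius bound:

    import torch

    cov_bound = 0.01
    old_cov = torch.full((2, 3), 0.5)  # diagonal variances, batch of 2
    cov = torch.full((2, 3), 0.9)      # Frobenius part 3 * 0.4^2 = 0.48 > bound

    cov_part = torch.sum(torch.square(old_cov - cov), dim=-1)
    eta = torch.where(cov_part > cov_bound,
                      torch.sqrt(cov_part / cov_bound) - 1.,
                      torch.ones_like(cov_part))
    eta = torch.maximum(-eta, eta)

    proj_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
    print(torch.sum(torch.square(old_cov - proj_cov), dim=-1))  # tensor([0.0100, 0.0100])
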
diff --git a/fancy_rl/projections/kl_projection.py b/fancy_rl/projections/kl_projection.py
index cd1a5fb..aaf7f7a 100644
--- a/fancy_rl/projections/kl_projection.py
+++ b/fancy_rl/projections/kl_projection.py
@@ -19,8 +19,17 @@ class KLProjection(BaseProjection):
         super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
 
     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
-        old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params[self.in_keys[1]]
+        self._validate_inputs(policy_params, old_policy_params)
+
+        mean = policy_params["loc"]
+        old_mean = old_policy_params["loc"]
+
+        if self.full_cov:
+            scale_or_tril = policy_params["scale_tril"]
+            old_scale_or_tril = old_policy_params["scale_tril"]
+        else:
+            scale_or_tril = policy_params["scale"]
+            old_scale_or_tril = old_policy_params["scale"]
 
         mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril), (old_mean, old_scale_or_tril))
 
@@ -35,11 +44,22 @@
         if not self.contextual_std:
             proj_scale_or_tril = proj_scale_or_tril.expand(mean.shape[0], *proj_scale_or_tril.shape[1:])
 
-        return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_tril}
+        if self.full_cov:
+            return {"loc": proj_mean, "scale_tril": proj_scale_or_tril}
+        else:
+            return {"loc": proj_mean, "scale": proj_scale_or_tril}
 
     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
-        mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
-        proj_mean, proj_scale_or_tril = proj_policy_params["loc"], proj_policy_params[self.out_keys[1]]
+        mean = policy_params["loc"]
+        proj_mean = proj_policy_params["loc"]
+
+        if self.full_cov:
+            scale_or_tril = policy_params["scale_tril"]
+            proj_scale_or_tril = proj_policy_params["scale_tril"]
+        else:
+            scale_or_tril = policy_params["scale"]
+            proj_scale_or_tril = proj_policy_params["scale"]
+
         kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
         return kl.mean() * self.trust_region_coeff
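
For the diagonal branch, mean_part and cov_part should recombine to the analytic Gaussian KL. A sanity-check sketch against torch.distributions, assuming the usual split (0.5 factors inside the two parts, old parameters as the second distribution); all tensors here are made up:

    import torch
    from torch.distributions import Independent, Normal, kl_divergence

    mean, scale = torch.zeros(2, 3), torch.full((2, 3), 0.8)
    old_mean, old_scale = torch.full((2, 3), 0.1), torch.ones(2, 3)

    # Pieces mirroring _maha, the diagonal trace term, and _log_determinant
    maha = torch.sum(torch.square((mean - old_mean) / old_scale), dim=-1)
    trace = torch.sum((scale / old_scale) ** 2, dim=-1)
    logdet = 2 * torch.sum(torch.log(scale), dim=-1)
    old_logdet = 2 * torch.sum(torch.log(old_scale), dim=-1)
    k = mean.shape[-1]

    kl_manual = 0.5 * (maha + trace - k + old_logdet - logdet)
    kl_torch = kl_divergence(Independent(Normal(mean, scale), 1),
                             Independent(Normal(old_mean, old_scale), 1))
    print(torch.allclose(kl_manual, kl_torch))  # True
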
@@ -54,7 +74,9 @@
         det_term_other = self._log_determinant(scale_or_tril_other)
 
         if self.full_cov:
-            trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(scale_or_tril_other, scale_or_tril, upper=False))
+            trace_part = self._batched_trace_square(
+                torch.triangular_solve(scale_or_tril, scale_or_tril_other, upper=False)[0]
+            )
         else:
             trace_part = torch.sum((scale_or_tril / scale_or_tril_other) ** 2, dim=-1)
 
@@ -65,62 +87,60 @@
     def _maha(self, x: torch.Tensor, y: torch.Tensor, scale_or_tril: torch.Tensor) -> torch.Tensor:
         diff = x - y
         if self.full_cov:
-            return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), scale_or_tril, upper=False)[0].squeeze(-1)), dim=-1)
+            solved = torch.triangular_solve(diff.unsqueeze(-1), scale_or_tril, upper=False)[0]
+            return torch.sum(torch.square(solved.squeeze(-1)), dim=-1)
         else:
             return torch.sum(torch.square(diff / scale_or_tril), dim=-1)
 
     def _log_determinant(self, scale_or_tril: torch.Tensor) -> torch.Tensor:
         if self.full_cov:
-            return 2 * torch.log(scale_or_tril.diagonal(dim1=-2, dim2=-1)).sum(-1)
+            return 2 * torch.sum(torch.log(torch.diagonal(scale_or_tril, dim1=-2, dim2=-1)), dim=-1)
         else:
-            return 2 * torch.log(scale_or_tril).sum(-1)
+            return 2 * torch.sum(torch.log(scale_or_tril), dim=-1)
 
-    def _torch_batched_trace_square(self, x: torch.Tensor) -> torch.Tensor:
-        return torch.sum(x.pow(2), dim=(-2, -1))
-
-    def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
-        return old_mean + (mean - old_mean) * torch.sqrt(self.mean_bound / (mean_part + 1e-8)).unsqueeze(-1)
+    def _batched_trace_square(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.sum(x ** 2, dim=(-2, -1))
 
     def _cov_projection(self, scale_or_tril: torch.Tensor, old_scale_or_tril: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
         if self.full_cov:
             cov = torch.matmul(scale_or_tril, scale_or_tril.transpose(-1, -2))
             old_cov = torch.matmul(old_scale_or_tril, old_scale_or_tril.transpose(-1, -2))
         else:
-            cov = scale_or_tril.pow(2)
-            old_cov = old_scale_or_tril.pow(2)
+            cov = scale_or_tril ** 2
+            old_cov = old_scale_or_tril ** 2
 
         mask = cov_part > self.cov_bound
-        proj_scale_or_tril = torch.zeros_like(scale_or_tril)
-        proj_scale_or_tril[~mask] = scale_or_tril[~mask]
-
-        try:
-            if mask.any():
-                if self.full_cov:
-                    proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, scale_or_tril.detach(), old_scale_or_tril, self.cov_bound)
-                    is_invalid = proj_cov.mean([-2, -1]).isnan() & mask
-                    if is_invalid.any():
-                        proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
-                        mask &= ~is_invalid
-                    proj_scale_or_tril[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
-                    failed_mask = failed_mask.bool()
-                    if failed_mask.any():
-                        proj_scale_or_tril[failed_mask] = old_scale_or_tril[failed_mask]
-                else:
-                    proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov, old_cov, self.cov_bound)
-                    is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
-                    if is_invalid.any():
-                        proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
-                        mask &= ~is_invalid
-                    proj_scale_or_tril[mask] = proj_cov[mask].sqrt()
-        except Exception as e:
-            import logging
-            logging.error('Projection failed, taking old scale_or_tril for projection.')
-            print("Projection failed, taking old scale_or_tril for projection.")
-            proj_scale_or_tril = old_scale_or_tril
-            raise e
+        proj_scale_or_tril = scale_or_tril  # Start with the original scale
+
+        if mask.any():
+            if self.full_cov:
+                proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound)
+                is_invalid = torch.isnan(proj_cov.mean(dim=(-2, -1)))
+                proj_scale_or_tril = torch.where(is_invalid[..., None, None], old_scale_or_tril, scale_or_tril)
+                mask = mask & ~is_invalid
+                chol = torch.linalg.cholesky(proj_cov)
+                proj_scale_or_tril = torch.where(mask[..., None, None], chol, proj_scale_or_tril)
+            else:
+                proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound)
+                is_invalid = (torch.isnan(proj_cov.mean(dim=-1)) |
+                              torch.isinf(proj_cov.mean(dim=-1)) |
+                              (proj_cov.min(dim=-1).values < 0))
+                proj_scale_or_tril = torch.where(is_invalid[..., None], old_scale_or_tril, scale_or_tril)
+                mask = mask & ~is_invalid
+                # Carry the fallback forward; do not overwrite it with the raw scale again
+                proj_scale_or_tril = torch.where(mask[..., None], torch.sqrt(proj_cov), proj_scale_or_tril)
 
         return proj_scale_or_tril
 
+    def _validate_inputs(self, policy_params, old_policy_params):
+        if self.full_cov:
+            required_keys = ["loc", "scale_tril"]
+        else:
+            required_keys = ["loc", "scale"]
+
+        for key in required_keys:
+            if key not in policy_params or key not in old_policy_params:
+                raise KeyError(f"Missing required key '{key}' in policy parameters")
+
 
 class KLProjectionGradFunctionCovOnly(torch.autograd.Function):
     projection_op = None
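
The rewritten `_cov_projection` drops in-place boolean indexing in favor of `torch.where` chains so that invalid solver outputs fall back to the old scale per batch element; note the final `where` has to carry the already-applied fallback forward rather than the raw scale. A small, self-contained illustration of that masking pattern (the NaN "solver output" is simulated; `project_diag_covariance` itself is not reproduced here):

    import torch

    scale = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
    old_scale = torch.tensor([[0.9, 0.9], [1.9, 1.9], [2.9, 2.9]])
    mask = torch.tensor([True, True, False])  # rows that need projecting

    # Simulated per-row projection result; row 1 "failed" with NaNs
    proj_cov = torch.tensor([[0.8, 0.8], [float('nan'), float('nan')], [9.0, 9.0]])

    is_invalid = torch.isnan(proj_cov.mean(dim=-1))
    out = torch.where(is_invalid[..., None], old_scale, scale)  # failed rows -> old scale
    mask = mask & ~is_invalid                                   # drop failed rows from the update
    out = torch.where(mask[..., None], torch.sqrt(proj_cov), out)
    print(out)  # row 0: sqrt(0.8), row 1: old_scale fallback, row 2: untouched scale
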
diff --git a/fancy_rl/projections/wasserstein_projection.py b/fancy_rl/projections/wasserstein_projection.py
index d19dc84..e00602b 100644
--- a/fancy_rl/projections/wasserstein_projection.py
+++ b/fancy_rl/projections/wasserstein_projection.py
@@ -1,6 +1,5 @@
 import torch
 from .base_projection import BaseProjection
-from tensordict.nn import TensorDictModule
 from typing import Dict, Tuple
 
 def scale_tril_to_sqrt(scale_tril: torch.Tensor) -> torch.Tensor:
@@ -12,75 +11,151 @@ def scale_tril_to_sqrt(scale_tril: torch.Tensor) -> torch.Tensor:
     """
     return scale_tril
 
-def gaussian_wasserstein_commutative(policy, p: Tuple[torch.Tensor, torch.Tensor],
-                                     q: Tuple[torch.Tensor, torch.Tensor], scale_prec=False) -> Tuple[torch.Tensor, torch.Tensor]:
+def gaussian_wasserstein_commutative(p: Tuple[torch.Tensor, torch.Tensor],
+                                     q: Tuple[torch.Tensor, torch.Tensor],
+                                     scale_prec: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
     mean, scale_or_sqrt = p
     mean_other, scale_or_sqrt_other = q
 
     mean_part = torch.sum(torch.square(mean - mean_other), dim=-1)
 
     if scale_or_sqrt.dim() == mean.dim():  # Diagonal case
-        cov = scale_or_sqrt.pow(2)
-        cov_other = scale_or_sqrt_other.pow(2)
         if scale_prec:
-            identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device)
-            sqrt_inv_other = 1 / scale_or_sqrt_other
-            c = sqrt_inv_other.pow(2) * cov
-            cov_part = torch.sum(identity + c - 2 * sqrt_inv_other * scale_or_sqrt, dim=-1)
+            # Note: the precision-scaled branch currently falls back to the standard W2 form
+            scale_part = torch.sum(
+                scale_or_sqrt_other**2 + scale_or_sqrt**2 -
+                2 * scale_or_sqrt_other * scale_or_sqrt,
+                dim=-1
+            )
         else:
-            cov_part = torch.sum(cov_other + cov - 2 * scale_or_sqrt_other * scale_or_sqrt, dim=-1)
+            # Standard W2 for the diagonal case
+            scale_part = torch.sum(
+                scale_or_sqrt_other**2 + scale_or_sqrt**2 -
+                2 * scale_or_sqrt_other * scale_or_sqrt,
+                dim=-1
+            )
     else:  # Full covariance case
         # Note: scale_or_sqrt is treated as the matrix square root, not the Cholesky decomposition
-        cov = torch.matmul(scale_or_sqrt, scale_or_sqrt.transpose(-1, -2))
-        cov_other = torch.matmul(scale_or_sqrt_other, scale_or_sqrt_other.transpose(-1, -2))
         if scale_prec:
+            # More stable implementation using a triangular solve
            identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device)
-            sqrt_inv_other = torch.linalg.solve(scale_or_sqrt_other, identity)
-            c = sqrt_inv_other @ cov @ sqrt_inv_other.transpose(-1, -2)
-            cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ scale_or_sqrt)
+            sqrt_inv_other = torch.triangular_solve(identity, scale_or_sqrt_other, upper=False)[0]
+            c = torch.matmul(sqrt_inv_other, scale_or_sqrt)
+            # tr(I) + tr(c c^T) - 2 tr(c); the linear term needs a trace, not an entrywise sum
+            scale_part = (mean.shape[-1]
+                          + torch.sum(c**2, dim=(-2, -1))
+                          - 2 * torch.diagonal(c, dim1=-2, dim2=-1).sum(-1))
         else:
-            cov_part = torch.trace(cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt)
+            # Standard W2 for full covariance: tr(cov_other) + tr(cov) - 2 tr(sqrt_other @ sqrt)
+            scale_part = (torch.sum(scale_or_sqrt_other**2, dim=(-2, -1))
+                          + torch.sum(scale_or_sqrt**2, dim=(-2, -1))
+                          - 2 * torch.diagonal(torch.matmul(scale_or_sqrt_other, scale_or_sqrt),
+                                               dim1=-2, dim2=-1).sum(-1))
 
-    return mean_part, cov_part
+    return mean_part, scale_part
 
 class WassersteinProjection(BaseProjection):
-    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
-        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
+    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0,
+                 mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False,
+                 contextual_std: bool = True):
+        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff,
+                         mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
         self.scale_prec = scale_prec
 
     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         mean = policy_params["loc"]
         old_mean = old_policy_params["loc"]
-        scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
-        old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params[self.in_keys[1]])
+
+        # scale_tril is used directly as the matrix square root (see scale_tril_to_sqrt above)
+        scale_sqrt = policy_params[self.in_keys[1]]
+        old_scale_sqrt = old_policy_params[self.in_keys[1]]
 
-        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (old_mean, old_scale_or_sqrt), self.scale_prec)
+        if not self.contextual_std:
+            scale_sqrt = scale_sqrt[:1]
+            old_scale_sqrt = old_scale_sqrt[:1]
+
+        mean_part, scale_part = self._gaussian_wasserstein(
+            (mean, scale_sqrt),
+            (old_mean, old_scale_sqrt)
+        )
 
         proj_mean = self._mean_projection(mean, old_mean, mean_part)
-        proj_scale_or_sqrt = self._cov_projection(scale_or_sqrt, old_scale_or_sqrt, cov_part)
+        proj_scale_sqrt = self._scale_projection(scale_sqrt, old_scale_sqrt, scale_part)
 
-        return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_sqrt}
+        if not self.contextual_std:
+            proj_scale_sqrt = proj_scale_sqrt.expand(mean.shape[0], *proj_scale_sqrt.shape[1:])
+
+        return {"loc": proj_mean, self.out_keys[1]: proj_scale_sqrt}
+
+    def _gaussian_wasserstein(self, p: Tuple[torch.Tensor, torch.Tensor],
+                              q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+        mean, scale_sqrt = p
+        mean_other, scale_sqrt_other = q
+
+        mean_part = torch.sum(torch.square(mean - mean_other), dim=-1)
+
+        if not self.full_cov:
+            # Diagonal case is simpler
+            scale_part = torch.sum(
+                scale_sqrt_other**2 + scale_sqrt**2 -
+                2 * scale_sqrt_other * scale_sqrt,
+                dim=-1
+            )
+        else:
+            # Full covariance case: tr(cov_other) + tr(cov) - 2 tr(sqrt_other @ sqrt)
+            scale_part = (torch.sum(scale_sqrt_other**2, dim=(-2, -1))
+                          + torch.sum(scale_sqrt**2, dim=(-2, -1))
+                          - 2 * torch.diagonal(torch.matmul(scale_sqrt_other, scale_sqrt),
+                                               dim1=-2, dim2=-1).sum(-1))
+
+        return mean_part, scale_part
 
     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
         mean = policy_params["loc"]
         proj_mean = proj_policy_params["loc"]
         scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
         proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params[self.out_keys[1]])
-        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (proj_mean, proj_scale_or_sqrt), self.scale_prec)
-        w2 = mean_part + cov_part
+
+        mean_part, scale_part = gaussian_wasserstein_commutative(
+            (mean, scale_or_sqrt),
+            (proj_mean, proj_scale_or_sqrt),
+            self.scale_prec
+        )
+        w2 = mean_part + scale_part
         return w2.mean() * self.trust_region_coeff
 
-    def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
-        diff = mean - old_mean
-        norm = torch.sqrt(mean_part)
-        return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm.unsqueeze(-1), mean)
-
-    def _cov_projection(self, scale_or_sqrt: torch.Tensor, old_scale_or_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
+    def _scale_projection(self, scale_or_sqrt: torch.Tensor, old_scale_or_sqrt: torch.Tensor, scale_part: torch.Tensor) -> torch.Tensor:
+        """Project scale parameters using a multiplicative update."""
         if scale_or_sqrt.dim() == old_scale_or_sqrt.dim() == 2:  # Diagonal case
-            diff = scale_or_sqrt - old_scale_or_sqrt
-            norm = torch.sqrt(cov_part)
-            return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm.unsqueeze(-1), scale_or_sqrt)
+            return self._diagonal_scale_projection(scale_or_sqrt, old_scale_or_sqrt, scale_part)
         else:  # Full covariance case
-            diff = scale_or_sqrt - old_scale_or_sqrt
-            norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
-            return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm, scale_or_sqrt)
+            return self._full_cov_scale_projection(scale_or_sqrt, old_scale_or_sqrt, scale_part)
+
+    def _diagonal_scale_projection(self, scale: torch.Tensor, old_scale: torch.Tensor, scale_part: torch.Tensor) -> torch.Tensor:
+        cov_mask = scale_part > self.cov_bound
+
+        batch_shape = scale.shape[:-1]
+        eta = torch.ones(batch_shape, dtype=scale.dtype, device=scale.device)
+        eta = torch.where(cov_mask,
+                          torch.sqrt(scale_part / self.cov_bound) - 1.,
+                          eta)
+        eta = torch.maximum(-eta, eta)
+
+        new_scale = (scale + eta[..., None] * old_scale) / \
+                    (1. + eta + 1e-16)[..., None]
+        # torch.where needs a boolean condition; do not cast the mask to a float dtype
+        return torch.where(cov_mask[..., None], new_scale, scale)
+
+    def _full_cov_scale_projection(self, scale_sqrt: torch.Tensor, old_scale_sqrt: torch.Tensor, scale_part: torch.Tensor) -> torch.Tensor:
+        cov_mask = scale_part > self.cov_bound
+
+        batch_shape = scale_sqrt.shape[:-2]
+        eta = torch.ones(batch_shape, dtype=scale_sqrt.dtype, device=scale_sqrt.device)
+        eta = torch.where(cov_mask,
+                          torch.sqrt(scale_part / self.cov_bound) - 1.,
+                          eta)
+        eta = torch.maximum(-eta, eta)
+
+        new_scale = (scale_sqrt + torch.einsum('...,...ij->...ij', eta, old_scale_sqrt)) / \
+                    (1. + eta + 1e-16)[..., None, None]
+        return torch.where(cov_mask[..., None, None], new_scale, scale_sqrt)
\ No newline at end of file
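
To see what the Wasserstein pieces add up to in the diagonal case: for commuting (here diagonal) covariances, the squared 2-Wasserstein distance has the closed form ||mean - mean_old||^2 + sum (scale - scale_old)^2, which is exactly what mean_part + scale_part recombine to. A standalone check with made-up tensors:

    import torch

    mean, scale = torch.zeros(5, 3), torch.full((5, 3), 1.5)
    old_mean, old_scale = torch.full((5, 3), 0.2), torch.ones(5, 3)

    mean_part = torch.sum(torch.square(mean - old_mean), dim=-1)
    scale_part = torch.sum(old_scale**2 + scale**2 - 2 * old_scale * scale, dim=-1)

    w2_closed = (torch.sum((mean - old_mean) ** 2, dim=-1)
                 + torch.sum((scale - old_scale) ** 2, dim=-1))
    print(torch.allclose(mean_part + scale_part, w2_closed))  # True
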