Fixing issues with projections

This commit is contained in:
Dominik Moritz Roth 2024-10-21 15:23:17 +02:00
parent 71cb8593d9
commit 651ef1522f
5 changed files with 199 additions and 108 deletions

View File

@ -1,16 +1,71 @@
from abc import ABC, abstractmethod
import torch
from typing import Dict
from torch import nn
from typing import Dict, List
class BaseProjection(ABC, torch.nn.Module):
def __init__(self, in_keys: list[str], out_keys: list[str]):
class BaseProjection(nn.Module, ABC):
def __init__(self, in_keys: List[str], out_keys: List[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
super().__init__()
self._validate_in_keys(in_keys)
self._validate_out_keys(out_keys)
self.in_keys = in_keys
self.out_keys = out_keys
self.trust_region_coeff = trust_region_coeff
self.mean_bound = mean_bound
self.cov_bound = cov_bound
self.full_cov = "scale_tril" in in_keys
self.contextual_std = contextual_std
def _validate_in_keys(self, keys: List[str]):
valid_keys = {"loc", "scale", "scale_tril", "old_loc", "old_scale", "old_scale_tril"}
if not set(keys).issubset(valid_keys):
raise ValueError(f"Invalid in_keys: {keys}. Must be a subset of {valid_keys}")
if "loc" not in keys or "old_loc" not in keys:
raise ValueError("Both 'loc' and 'old_loc' must be included in in_keys")
if ("scale" in keys) != ("old_scale" in keys) or ("scale_tril" in keys) != ("old_scale_tril" in keys):
raise ValueError("in_keys must have matching 'scale'/'old_scale' or 'scale_tril'/'old_scale_tril'")
def _validate_out_keys(self, keys: List[str]):
valid_keys = {"loc", "scale", "scale_tril"}
if not set(keys).issubset(valid_keys):
raise ValueError(f"Invalid out_keys: {keys}. Must be a subset of {valid_keys}")
if "loc" not in keys:
raise ValueError("'loc' must be included in out_keys")
if "scale" not in keys and "scale_tril" not in keys:
raise ValueError("Either 'scale' or 'scale_tril' must be included in out_keys")
@abstractmethod
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
pass
def forward(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return self.project(policy_params, old_policy_params)
@abstractmethod
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
pass
def forward(self, tensordict):
policy_params = {}
old_policy_params = {}
for key in self.in_keys:
if key not in tensordict:
raise KeyError(f"Key '{key}' not found in tensordict. Available keys: {tensordict.keys()}")
if key.startswith("old_"):
old_policy_params[key[4:]] = tensordict[key]
else:
policy_params[key] = tensordict[key]
projected_params = self.project(policy_params, old_policy_params)
return projected_params
def _calc_covariance(self, params: Dict[str, torch.Tensor]) -> torch.Tensor:
if not self.full_cov:
return torch.diag_embed(params["scale"].pow(2))
else:
return torch.matmul(params["scale_tril"], params["scale_tril"].transpose(-1, -2))
def _calc_scale_or_scale_tril(self, cov: torch.Tensor) -> torch.Tensor:
if not self.full_cov:
return torch.sqrt(cov.diagonal(dim1=-2, dim2=-1))
else:
return torch.linalg.cholesky(cov)

View File

@ -1,33 +1,34 @@
import torch
from .base_projection import BaseProjection
from tensordict.nn import TensorDictModule
from typing import Dict
class FrobeniusProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
self.scale_prec = scale_prec
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
mean, chol = policy_params["loc"], policy_params["scale_tril"]
old_mean, old_chol = old_policy_params["loc"], old_policy_params["scale_tril"]
mean = policy_params["loc"]
old_mean = old_policy_params["loc"]
cov = torch.matmul(chol, chol.transpose(-1, -2))
old_cov = torch.matmul(old_chol, old_chol.transpose(-1, -2))
cov = self._calc_covariance(policy_params)
old_cov = self._calc_covariance(old_policy_params)
mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_cov = self._cov_projection(cov, old_cov, cov_part)
proj_chol = torch.linalg.cholesky(proj_cov)
return {"loc": proj_mean, "scale_tril": proj_chol}
scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov)
return {"loc": proj_mean, self.out_keys[1]: scale_or_scale_tril}
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
mean, chol = policy_params["loc"], policy_params["scale_tril"]
proj_mean, proj_chol = proj_policy_params["loc"], proj_policy_params["scale_tril"]
mean = policy_params["loc"]
proj_mean = proj_policy_params["loc"]
cov = torch.matmul(chol, chol.transpose(-1, -2))
proj_cov = torch.matmul(proj_chol, proj_chol.transpose(-1, -2))
cov = self._calc_covariance(policy_params)
proj_cov = self._calc_covariance(proj_policy_params)
mean_diff = torch.sum(torch.square(mean - proj_mean), dim=-1)
cov_diff = torch.sum(torch.square(cov - proj_cov), dim=(-2, -1))

View File

@ -3,8 +3,8 @@ from .base_projection import BaseProjection
from typing import Dict
class IdentityProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return policy_params

View File

@ -2,6 +2,7 @@ import torch
import cpp_projection
import numpy as np
from .base_projection import BaseProjection
from tensordict.nn import TensorDictModule
from typing import Dict, Tuple, Any
MAX_EVAL = 1000
@ -10,57 +11,65 @@ def get_numpy(tensor):
return tensor.detach().cpu().numpy()
class KLProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, is_diag: bool = True, contextual_std: bool = True):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
self.is_diag = is_diag
self.contextual_std = contextual_std
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
mean, std = policy_params["loc"], policy_params["scale_tril"]
old_mean, old_std = old_policy_params["loc"], old_policy_params["scale_tril"]
mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params[self.in_keys[1]]
mean_part, cov_part = self._gaussian_kl((mean, std), (old_mean, old_std))
mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril), (old_mean, old_scale_or_tril))
if not self.contextual_std:
std = std[:1]
old_std = old_std[:1]
scale_or_tril = scale_or_tril[:1]
old_scale_or_tril = old_scale_or_tril[:1]
cov_part = cov_part[:1]
proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_std = self._cov_projection(std, old_std, cov_part)
proj_scale_or_tril = self._cov_projection(scale_or_tril, old_scale_or_tril, cov_part)
if not self.contextual_std:
proj_std = proj_std.expand(mean.shape[0], -1, -1)
proj_scale_or_tril = proj_scale_or_tril.expand(mean.shape[0], *proj_scale_or_tril.shape[1:])
return {"loc": proj_mean, "scale_tril": proj_std}
return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_tril}
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
mean, std = policy_params["loc"], policy_params["scale_tril"]
proj_mean, proj_std = proj_policy_params["loc"], proj_policy_params["scale_tril"]
kl = sum(self._gaussian_kl((mean, std), (proj_mean, proj_std)))
mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
proj_mean, proj_scale_or_tril = proj_policy_params["loc"], proj_policy_params[self.out_keys[1]]
kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
return kl.mean() * self.trust_region_coeff
def _gaussian_kl(self, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
mean, std = p
mean_other, std_other = q
mean, scale_or_tril = p
mean_other, scale_or_tril_other = q
k = mean.shape[-1]
maha_part = 0.5 * self._maha(mean, mean_other, std_other)
maha_part = 0.5 * self._maha(mean, mean_other, scale_or_tril_other)
det_term = self._log_determinant(std)
det_term_other = self._log_determinant(std_other)
det_term = self._log_determinant(scale_or_tril)
det_term_other = self._log_determinant(scale_or_tril_other)
if self.full_cov:
trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(scale_or_tril_other, scale_or_tril, upper=False))
else:
trace_part = torch.sum((scale_or_tril / scale_or_tril_other) ** 2, dim=-1)
trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(std_other, std, upper=False))
cov_part = 0.5 * (trace_part - k + det_term_other - det_term)
return maha_part, cov_part
def _maha(self, x: torch.Tensor, y: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
def _maha(self, x: torch.Tensor, y: torch.Tensor, scale_or_tril: torch.Tensor) -> torch.Tensor:
diff = x - y
return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), std, upper=False)[0].squeeze(-1)), dim=-1)
if self.full_cov:
return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), scale_or_tril, upper=False)[0].squeeze(-1)), dim=-1)
else:
return torch.sum(torch.square(diff / scale_or_tril), dim=-1)
def _log_determinant(self, std: torch.Tensor) -> torch.Tensor:
return 2 * torch.log(std.diagonal(dim1=-2, dim2=-1)).sum(-1)
def _log_determinant(self, scale_or_tril: torch.Tensor) -> torch.Tensor:
if self.full_cov:
return 2 * torch.log(scale_or_tril.diagonal(dim1=-2, dim2=-1)).sum(-1)
else:
return 2 * torch.log(scale_or_tril).sum(-1)
def _torch_batched_trace_square(self, x: torch.Tensor) -> torch.Tensor:
return torch.sum(x.pow(2), dim=(-2, -1))
@ -68,49 +77,45 @@ class KLProjection(BaseProjection):
def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
return old_mean + (mean - old_mean) * torch.sqrt(self.mean_bound / (mean_part + 1e-8)).unsqueeze(-1)
def _cov_projection(self, std: torch.Tensor, old_std: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
cov = torch.matmul(std, std.transpose(-1, -2))
old_cov = torch.matmul(old_std, old_std.transpose(-1, -2))
if self.is_diag:
mask = cov_part > self.cov_bound
proj_std = torch.zeros_like(std)
proj_std[~mask] = std[~mask]
try:
if mask.any():
proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov.diagonal(dim1=-2, dim2=-1),
old_cov.diagonal(dim1=-2, dim2=-1),
self.cov_bound)
is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
if is_invalid.any():
proj_std[is_invalid] = old_std[is_invalid]
mask &= ~is_invalid
proj_std[mask] = proj_cov[mask].sqrt().diag_embed()
except Exception as e:
proj_std = old_std
def _cov_projection(self, scale_or_tril: torch.Tensor, old_scale_or_tril: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
if self.full_cov:
cov = torch.matmul(scale_or_tril, scale_or_tril.transpose(-1, -2))
old_cov = torch.matmul(old_scale_or_tril, old_scale_or_tril.transpose(-1, -2))
else:
try:
cov = scale_or_tril.pow(2)
old_cov = old_scale_or_tril.pow(2)
mask = cov_part > self.cov_bound
proj_std = torch.zeros_like(std)
proj_std[~mask] = std[~mask]
proj_scale_or_tril = torch.zeros_like(scale_or_tril)
proj_scale_or_tril[~mask] = scale_or_tril[~mask]
try:
if mask.any():
proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, std.detach(), old_std, self.cov_bound)
if self.full_cov:
proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, scale_or_tril.detach(), old_scale_or_tril, self.cov_bound)
is_invalid = proj_cov.mean([-2, -1]).isnan() & mask
if is_invalid.any():
proj_std[is_invalid] = old_std[is_invalid]
proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
mask &= ~is_invalid
proj_std[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
proj_scale_or_tril[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
failed_mask = failed_mask.bool()
if failed_mask.any():
proj_std[failed_mask] = old_std[failed_mask]
proj_scale_or_tril[failed_mask] = old_scale_or_tril[failed_mask]
else:
proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov, old_cov, self.cov_bound)
is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
if is_invalid.any():
proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
mask &= ~is_invalid
proj_scale_or_tril[mask] = proj_cov[mask].sqrt()
except Exception as e:
import logging
logging.error('Projection failed, taking old cholesky for projection.')
print("Projection failed, taking old cholesky for projection.")
proj_std = old_std
logging.error('Projection failed, taking old scale_or_tril for projection.')
print("Projection failed, taking old scale_or_tril for projection.")
proj_scale_or_tril = old_scale_or_tril
raise e
return proj_std
return proj_scale_or_tril
class KLProjectionGradFunctionCovOnly(torch.autograd.Function):

View File

@ -1,56 +1,86 @@
import torch
from .base_projection import BaseProjection
from tensordict.nn import TensorDictModule
from typing import Dict, Tuple
def scale_tril_to_sqrt(scale_tril: torch.Tensor) -> torch.Tensor:
"""
'Converts' scale_tril to scale_sqrt.
For Wasserstein distance, we need the matrix square root, not the Cholesky decomposition.
But since both are lower triangular, we can treat the Cholesky decomposition as if it were the matrix square root.
"""
return scale_tril
def gaussian_wasserstein_commutative(policy, p: Tuple[torch.Tensor, torch.Tensor],
q: Tuple[torch.Tensor, torch.Tensor], scale_prec=False) -> Tuple[torch.Tensor, torch.Tensor]:
mean, sqrt = p
mean_other, sqrt_other = q
mean, scale_or_sqrt = p
mean_other, scale_or_sqrt_other = q
mean_part = torch.sum(torch.square(mean - mean_other), dim=-1)
cov = torch.matmul(sqrt, sqrt.transpose(-1, -2))
cov_other = torch.matmul(sqrt_other, sqrt_other.transpose(-1, -2))
if scale_or_sqrt.dim() == mean.dim(): # Diagonal case
cov = scale_or_sqrt.pow(2)
cov_other = scale_or_sqrt_other.pow(2)
if scale_prec:
identity = torch.eye(mean.shape[-1], dtype=sqrt.dtype, device=sqrt.device)
sqrt_inv_other = torch.linalg.solve(sqrt_other, identity)
c = sqrt_inv_other @ cov @ sqrt_inv_other
cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ sqrt)
identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device)
sqrt_inv_other = 1 / scale_or_sqrt_other
c = sqrt_inv_other.pow(2) * cov
cov_part = torch.sum(identity + c - 2 * sqrt_inv_other * scale_or_sqrt, dim=-1)
else:
cov_part = torch.trace(cov_other + cov - 2 * sqrt_other @ sqrt)
cov_part = torch.sum(cov_other + cov - 2 * scale_or_sqrt_other * scale_or_sqrt, dim=-1)
else: # Full covariance case
# Note: scale_or_sqrt is treated as the matrix square root, not Cholesky decomposition
cov = torch.matmul(scale_or_sqrt, scale_or_sqrt.transpose(-1, -2))
cov_other = torch.matmul(scale_or_sqrt_other, scale_or_sqrt_other.transpose(-1, -2))
if scale_prec:
identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device)
sqrt_inv_other = torch.linalg.solve(scale_or_sqrt_other, identity)
c = sqrt_inv_other @ cov @ sqrt_inv_other.transpose(-1, -2)
cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ scale_or_sqrt)
else:
cov_part = torch.trace(cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt)
return mean_part, cov_part
class WassersteinProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
self.scale_prec = scale_prec
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
old_mean, old_sqrt = old_policy_params["loc"], old_policy_params["scale_tril"]
mean = policy_params["loc"]
old_mean = old_policy_params["loc"]
scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params[self.in_keys[1]])
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (old_mean, old_sqrt), self.scale_prec)
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (old_mean, old_scale_or_sqrt), self.scale_prec)
proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_sqrt = self._cov_projection(sqrt, old_sqrt, cov_part)
proj_scale_or_sqrt = self._cov_projection(scale_or_sqrt, old_scale_or_sqrt, cov_part)
return {"loc": proj_mean, "scale_tril": proj_sqrt}
return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_sqrt}
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
proj_mean, proj_sqrt = proj_policy_params["loc"], proj_policy_params["scale_tril"]
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (proj_mean, proj_sqrt), self.scale_prec)
mean = policy_params["loc"]
proj_mean = proj_policy_params["loc"]
scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params[self.out_keys[1]])
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (proj_mean, proj_scale_or_sqrt), self.scale_prec)
w2 = mean_part + cov_part
return w2.mean() * self.trust_region_coeff
def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
diff = mean - old_mean
norm = torch.norm(diff, dim=-1, keepdim=True)
return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm, mean)
norm = torch.sqrt(mean_part)
return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm.unsqueeze(-1), mean)
def _cov_projection(self, sqrt: torch.Tensor, old_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
diff = sqrt - old_sqrt
def _cov_projection(self, scale_or_sqrt: torch.Tensor, old_scale_or_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
if scale_or_sqrt.dim() == old_scale_or_sqrt.dim() == 2: # Diagonal case
diff = scale_or_sqrt - old_scale_or_sqrt
norm = torch.sqrt(cov_part)
return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm.unsqueeze(-1), scale_or_sqrt)
else: # Full covariance case
diff = scale_or_sqrt - old_scale_or_sqrt
norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
return torch.where(norm > self.cov_bound, old_sqrt + diff * self.cov_bound / norm, sqrt)
return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm, scale_or_sqrt)