diff --git a/fancy_rl/projections/base_projection.py b/fancy_rl/projections/base_projection.py
index 31f1fa0..3ceb5d1 100644
--- a/fancy_rl/projections/base_projection.py
+++ b/fancy_rl/projections/base_projection.py
@@ -1,16 +1,71 @@
 from abc import ABC, abstractmethod
 import torch
-from typing import Dict
+from torch import nn
+from typing import Dict, List
 
-class BaseProjection(ABC, torch.nn.Module):
-    def __init__(self, in_keys: list[str], out_keys: list[str]):
+class BaseProjection(nn.Module, ABC):
+    def __init__(self, in_keys: List[str], out_keys: List[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
         super().__init__()
+        self._validate_in_keys(in_keys)
+        self._validate_out_keys(out_keys)
         self.in_keys = in_keys
         self.out_keys = out_keys
+        self.trust_region_coeff = trust_region_coeff
+        self.mean_bound = mean_bound
+        self.cov_bound = cov_bound
+        self.full_cov = "scale_tril" in in_keys
+        self.contextual_std = contextual_std
+
+    def _validate_in_keys(self, keys: List[str]):
+        valid_keys = {"loc", "scale", "scale_tril", "old_loc", "old_scale", "old_scale_tril"}
+        if not set(keys).issubset(valid_keys):
+            raise ValueError(f"Invalid in_keys: {keys}. Must be a subset of {valid_keys}")
+        if "loc" not in keys or "old_loc" not in keys:
+            raise ValueError("Both 'loc' and 'old_loc' must be included in in_keys")
+        if ("scale" in keys) != ("old_scale" in keys) or ("scale_tril" in keys) != ("old_scale_tril" in keys):
+            raise ValueError("in_keys must have matching 'scale'/'old_scale' or 'scale_tril'/'old_scale_tril'")
+
+    def _validate_out_keys(self, keys: List[str]):
+        valid_keys = {"loc", "scale", "scale_tril"}
+        if not set(keys).issubset(valid_keys):
+            raise ValueError(f"Invalid out_keys: {keys}. Must be a subset of {valid_keys}")
+        if "loc" not in keys:
+            raise ValueError("'loc' must be included in out_keys")
+        if "scale" not in keys and "scale_tril" not in keys:
+            raise ValueError("Either 'scale' or 'scale_tril' must be included in out_keys")
 
     @abstractmethod
     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         pass
 
-    def forward(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        return self.project(policy_params, old_policy_params)
\ No newline at end of file
+    @abstractmethod
+    def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
+        pass
+
+    def forward(self, tensordict):
+        policy_params = {}
+        old_policy_params = {}
+
+        for key in self.in_keys:
+            if key not in tensordict:
+                raise KeyError(f"Key '{key}' not found in tensordict. Available keys: {tensordict.keys()}")
+
+            if key.startswith("old_"):
+                old_policy_params[key[4:]] = tensordict[key]
+            else:
+                policy_params[key] = tensordict[key]
+
+        projected_params = self.project(policy_params, old_policy_params)
+        return projected_params
+
+    def _calc_covariance(self, params: Dict[str, torch.Tensor]) -> torch.Tensor:
+        if not self.full_cov:
+            return torch.diag_embed(params["scale"].pow(2))
+        else:
+            return torch.matmul(params["scale_tril"], params["scale_tril"].transpose(-1, -2))
+
+    def _calc_scale_or_scale_tril(self, cov: torch.Tensor) -> torch.Tensor:
+        if not self.full_cov:
+            return torch.sqrt(cov.diagonal(dim1=-2, dim2=-1))
+        else:
+            return torch.linalg.cholesky(cov)
\ No newline at end of file
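Review sketch (not part of the diff): a minimal, hypothetical `NoOpProjection` subclass exercising the new key validation and the `forward()` plumbing. A plain dict stands in for a TensorDict here, since `forward()` only tests membership and indexes by key.

```python
import torch
from fancy_rl.projections.base_projection import BaseProjection

class NoOpProjection(BaseProjection):
    """Hypothetical subclass, used only to exercise validation and forward()."""
    def project(self, policy_params, old_policy_params):
        return policy_params

    def get_trust_region_loss(self, policy_params, proj_policy_params):
        return torch.zeros(())

proj = NoOpProjection(in_keys=["loc", "scale", "old_loc", "old_scale"],
                      out_keys=["loc", "scale"])  # diagonal case: full_cov is False
batch = {"loc": torch.zeros(4, 2), "scale": torch.ones(4, 2),
         "old_loc": torch.zeros(4, 2), "old_scale": torch.ones(4, 2)}
out = proj(batch)  # forward() splits keys on the "old_" prefix, then calls project()
assert out["loc"].shape == (4, 2)

# Mismatched keys now fail fast in __init__, e.g. dropping "old_scale" raises:
# ValueError: in_keys must have matching 'scale'/'old_scale' or 'scale_tril'/'old_scale_tril'
```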
diff --git a/fancy_rl/projections/frobenius_projection.py b/fancy_rl/projections/frobenius_projection.py
index 9ad3480..2a92c04 100644
--- a/fancy_rl/projections/frobenius_projection.py
+++ b/fancy_rl/projections/frobenius_projection.py
@@ -1,33 +1,34 @@
 import torch
 from .base_projection import BaseProjection
+from tensordict.nn import TensorDictModule
 from typing import Dict
 
 class FrobeniusProjection(BaseProjection):
-    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
-        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
+    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
+        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
         self.scale_prec = scale_prec
 
     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        mean, chol = policy_params["loc"], policy_params["scale_tril"]
-        old_mean, old_chol = old_policy_params["loc"], old_policy_params["scale_tril"]
+        mean = policy_params["loc"]
+        old_mean = old_policy_params["loc"]
 
-        cov = torch.matmul(chol, chol.transpose(-1, -2))
-        old_cov = torch.matmul(old_chol, old_chol.transpose(-1, -2))
+        cov = self._calc_covariance(policy_params)
+        old_cov = self._calc_covariance(old_policy_params)
 
         mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
 
         proj_mean = self._mean_projection(mean, old_mean, mean_part)
         proj_cov = self._cov_projection(cov, old_cov, cov_part)
 
-        proj_chol = torch.linalg.cholesky(proj_cov)
-        return {"loc": proj_mean, "scale_tril": proj_chol}
+        scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov)
+        return {"loc": proj_mean, self.out_keys[1]: scale_or_scale_tril}
 
     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
-        mean, chol = policy_params["loc"], policy_params["scale_tril"]
-        proj_mean, proj_chol = proj_policy_params["loc"], proj_policy_params["scale_tril"]
+        mean = policy_params["loc"]
+        proj_mean = proj_policy_params["loc"]
 
-        cov = torch.matmul(chol, chol.transpose(-1, -2))
-        proj_cov = torch.matmul(proj_chol, proj_chol.transpose(-1, -2))
+        cov = self._calc_covariance(policy_params)
+        proj_cov = self._calc_covariance(proj_policy_params)
 
         mean_diff = torch.sum(torch.square(mean - proj_mean), dim=-1)
         cov_diff = torch.sum(torch.square(cov - proj_cov), dim=(-2, -1))
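Sanity-check sketch (standalone; it mirrors the two base-class helpers rather than importing them): the diagonal and full-covariance paths round-trip, which is why `project()` can emit either `scale` or `scale_tril` via `self.out_keys[1]`.

```python
import torch

scale = torch.rand(4, 3) + 0.1
cov = torch.diag_embed(scale.pow(2))                    # _calc_covariance, diagonal path
recovered = torch.sqrt(cov.diagonal(dim1=-2, dim2=-1))  # _calc_scale_or_scale_tril, diagonal path
assert torch.allclose(recovered, scale)

tril = torch.linalg.cholesky(cov)                       # full-covariance path
assert torch.allclose(tril @ tril.transpose(-1, -2), cov)
```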
diff --git a/fancy_rl/projections/identity_projection.py b/fancy_rl/projections/identity_projection.py
index adb8af3..c11c3e9 100644
--- a/fancy_rl/projections/identity_projection.py
+++ b/fancy_rl/projections/identity_projection.py
@@ -3,8 +3,8 @@ from .base_projection import BaseProjection
 from typing import Dict
 
 class IdentityProjection(BaseProjection):
-    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01):
-        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
+    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
+        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
 
     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return policy_params
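Review note: `get_trust_region_loss` is now abstract on `BaseProjection`, so unless `IdentityProjection` defines it further down in the file (not shown in this hunk), it can no longer be instantiated. A zero penalty would match the pass-through semantics; a hypothetical sketch of that method:

```python
import torch
from typing import Dict

# Hypothetical addition to IdentityProjection (not in the diff):
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor],
                          proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
    # The identity projection changes nothing, so the trust-region penalty is zero.
    return torch.zeros((), device=policy_params["loc"].device)
```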
diff --git a/fancy_rl/projections/kl_projection.py b/fancy_rl/projections/kl_projection.py
index 4dddecd..eb65fd8 100644
--- a/fancy_rl/projections/kl_projection.py
+++ b/fancy_rl/projections/kl_projection.py
@@ -2,6 +2,7 @@ import torch
 import cpp_projection
 import numpy as np
 from .base_projection import BaseProjection
+from tensordict.nn import TensorDictModule
 from typing import Dict, Tuple, Any
 
 MAX_EVAL = 1000
@@ -10,57 +11,65 @@ def get_numpy(tensor):
     return tensor.detach().cpu().numpy()
 
 class KLProjection(BaseProjection):
-    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, is_diag: bool = True, contextual_std: bool = True):
-        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
-        self.is_diag = is_diag
-        self.contextual_std = contextual_std
+    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
+        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
 
     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        mean, std = policy_params["loc"], policy_params["scale_tril"]
-        old_mean, old_std = old_policy_params["loc"], old_policy_params["scale_tril"]
+        mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
+        old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params[self.in_keys[1]]
 
-        mean_part, cov_part = self._gaussian_kl((mean, std), (old_mean, old_std))
+        mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril), (old_mean, old_scale_or_tril))
 
         if not self.contextual_std:
-            std = std[:1]
-            old_std = old_std[:1]
+            scale_or_tril = scale_or_tril[:1]
+            old_scale_or_tril = old_scale_or_tril[:1]
             cov_part = cov_part[:1]
 
         proj_mean = self._mean_projection(mean, old_mean, mean_part)
-        proj_std = self._cov_projection(std, old_std, cov_part)
+        proj_scale_or_tril = self._cov_projection(scale_or_tril, old_scale_or_tril, cov_part)
 
         if not self.contextual_std:
-            proj_std = proj_std.expand(mean.shape[0], -1, -1)
+            proj_scale_or_tril = proj_scale_or_tril.expand(mean.shape[0], *proj_scale_or_tril.shape[1:])
 
-        return {"loc": proj_mean, "scale_tril": proj_std}
+        return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_tril}
 
     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
-        mean, std = policy_params["loc"], policy_params["scale_tril"]
-        proj_mean, proj_std = proj_policy_params["loc"], proj_policy_params["scale_tril"]
-        kl = sum(self._gaussian_kl((mean, std), (proj_mean, proj_std)))
+        mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
+        proj_mean, proj_scale_or_tril = proj_policy_params["loc"], proj_policy_params[self.out_keys[1]]
+        kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
         return kl.mean() * self.trust_region_coeff
 
     def _gaussian_kl(self, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
-        mean, std = p
-        mean_other, std_other = q
+        mean, scale_or_tril = p
+        mean_other, scale_or_tril_other = q
         k = mean.shape[-1]
 
-        maha_part = 0.5 * self._maha(mean, mean_other, std_other)
+        maha_part = 0.5 * self._maha(mean, mean_other, scale_or_tril_other)
 
-        det_term = self._log_determinant(std)
-        det_term_other = self._log_determinant(std_other)
+        det_term = self._log_determinant(scale_or_tril)
+        det_term_other = self._log_determinant(scale_or_tril_other)
+
+        if self.full_cov:
+            trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(scale_or_tril_other, scale_or_tril, upper=False))
+        else:
+            trace_part = torch.sum((scale_or_tril / scale_or_tril_other) ** 2, dim=-1)
 
-        trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(std_other, std, upper=False))
         cov_part = 0.5 * (trace_part - k + det_term_other - det_term)
 
         return maha_part, cov_part
 
-    def _maha(self, x: torch.Tensor, y: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
+    def _maha(self, x: torch.Tensor, y: torch.Tensor, scale_or_tril: torch.Tensor) -> torch.Tensor:
         diff = x - y
-        return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), std, upper=False)[0].squeeze(-1)), dim=-1)
+        if self.full_cov:
+            # torch.triangular_solve is deprecated; use torch.linalg.solve_triangular instead
+            return torch.sum(torch.square(torch.linalg.solve_triangular(scale_or_tril, diff.unsqueeze(-1), upper=False).squeeze(-1)), dim=-1)
+        else:
+            return torch.sum(torch.square(diff / scale_or_tril), dim=-1)
 
-    def _log_determinant(self, std: torch.Tensor) -> torch.Tensor:
-        return 2 * torch.log(std.diagonal(dim1=-2, dim2=-1)).sum(-1)
+    def _log_determinant(self, scale_or_tril: torch.Tensor) -> torch.Tensor:
+        if self.full_cov:
+            return 2 * torch.log(scale_or_tril.diagonal(dim1=-2, dim2=-1)).sum(-1)
+        else:
+            return 2 * torch.log(scale_or_tril).sum(-1)
 
     def _torch_batched_trace_square(self, x: torch.Tensor) -> torch.Tensor:
         return torch.sum(x.pow(2), dim=(-2, -1))
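Standalone sanity check (my own recomputation, not repo code): the diagonal branch of `_gaussian_kl` splits KL(new || old) into the same Mahalanobis and covariance terms that `torch.distributions` computes in one go.

```python
import torch
from torch.distributions import Independent, Normal, kl_divergence

mean, scale = torch.randn(5, 3), torch.rand(5, 3) + 0.1
old_mean, old_scale = torch.randn(5, 3), torch.rand(5, 3) + 0.1
k = mean.shape[-1]

# Diagonal branch of _gaussian_kl, written out explicitly:
maha_part = 0.5 * torch.sum(torch.square((mean - old_mean) / old_scale), dim=-1)
trace_part = torch.sum((scale / old_scale) ** 2, dim=-1)
det_term, det_term_other = 2 * torch.log(scale).sum(-1), 2 * torch.log(old_scale).sum(-1)
cov_part = 0.5 * (trace_part - k + det_term_other - det_term)

reference = kl_divergence(Independent(Normal(mean, scale), 1),
                          Independent(Normal(old_mean, old_scale), 1))
assert torch.allclose(maha_part + cov_part, reference, atol=1e-5)
```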
@@ -68,49 +77,45 @@ class KLProjection(BaseProjection):
     def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
         return old_mean + (mean - old_mean) * torch.sqrt(self.mean_bound / (mean_part + 1e-8)).unsqueeze(-1)
 
-    def _cov_projection(self, std: torch.Tensor, old_std: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
-        cov = torch.matmul(std, std.transpose(-1, -2))
-        old_cov = torch.matmul(old_std, old_std.transpose(-1, -2))
-
-        if self.is_diag:
-            mask = cov_part > self.cov_bound
-            proj_std = torch.zeros_like(std)
-            proj_std[~mask] = std[~mask]
-            try:
-                if mask.any():
-                    proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov.diagonal(dim1=-2, dim2=-1),
-                                                                         old_cov.diagonal(dim1=-2, dim2=-1),
-                                                                         self.cov_bound)
-                    is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
-                    if is_invalid.any():
-                        proj_std[is_invalid] = old_std[is_invalid]
-                        mask &= ~is_invalid
-                    proj_std[mask] = proj_cov[mask].sqrt().diag_embed()
-            except Exception as e:
-                proj_std = old_std
+    def _cov_projection(self, scale_or_tril: torch.Tensor, old_scale_or_tril: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
+        if self.full_cov:
+            cov = torch.matmul(scale_or_tril, scale_or_tril.transpose(-1, -2))
+            old_cov = torch.matmul(old_scale_or_tril, old_scale_or_tril.transpose(-1, -2))
         else:
-            try:
-                mask = cov_part > self.cov_bound
-                proj_std = torch.zeros_like(std)
-                proj_std[~mask] = std[~mask]
-                if mask.any():
-                    proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, std.detach(), old_std, self.cov_bound)
+            cov = scale_or_tril.pow(2)
+            old_cov = old_scale_or_tril.pow(2)
+
+        mask = cov_part > self.cov_bound
+        proj_scale_or_tril = torch.zeros_like(scale_or_tril)
+        proj_scale_or_tril[~mask] = scale_or_tril[~mask]
+
+        try:
+            if mask.any():
+                if self.full_cov:
+                    proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, scale_or_tril.detach(), old_scale_or_tril, self.cov_bound)
                     is_invalid = proj_cov.mean([-2, -1]).isnan() & mask
                     if is_invalid.any():
-                        proj_std[is_invalid] = old_std[is_invalid]
+                        proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
                         mask &= ~is_invalid
-                    proj_std[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
+                    proj_scale_or_tril[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
                     failed_mask = failed_mask.bool()
                     if failed_mask.any():
-                        proj_std[failed_mask] = old_std[failed_mask]
-            except Exception as e:
-                import logging
-                logging.error('Projection failed, taking old cholesky for projection.')
-                print("Projection failed, taking old cholesky for projection.")
-                proj_std = old_std
-                raise e
+                        proj_scale_or_tril[failed_mask] = old_scale_or_tril[failed_mask]
+                else:
+                    proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov, old_cov, self.cov_bound)
+                    is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
+                    if is_invalid.any():
+                        proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
+                        mask &= ~is_invalid
+                    proj_scale_or_tril[mask] = proj_cov[mask].sqrt()
+        except Exception as e:
+            import logging
+            logging.error('Projection failed, taking old scale_or_tril for projection.')
+            print("Projection failed, taking old scale_or_tril for projection.")
+            proj_scale_or_tril = old_scale_or_tril
+            raise e
 
-        return proj_std
+        return proj_scale_or_tril
 
 
 class KLProjectionGradFunctionCovOnly(torch.autograd.Function):
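Sketch of the dispatch-and-fallback pattern above on dummy data (the `cpp_projection` autograd functions are compiled extensions, so the projected result is stubbed with the old scales here): only samples whose `cov_part` exceeds `cov_bound` are overwritten.

```python
import torch

cov_bound = 0.01
scale = torch.rand(6, 3) + 0.1       # new diagonal scales
old_scale = torch.rand(6, 3) + 0.1   # old scales, also the fallback on failure
cov_part = torch.rand(6)             # per-sample covariance divergence

mask = cov_part > cov_bound
proj_scale = torch.zeros_like(scale)
proj_scale[~mask] = scale[~mask]     # within the bound: keep the new scale
proj_scale[mask] = old_scale[mask]   # stand-in for KLProjectionGradFunctionDiagCovOnly.apply(...)
```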
+ """ + return scale_tril + def gaussian_wasserstein_commutative(policy, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor], scale_prec=False) -> Tuple[torch.Tensor, torch.Tensor]: - mean, sqrt = p - mean_other, sqrt_other = q + mean, scale_or_sqrt = p + mean_other, scale_or_sqrt_other = q mean_part = torch.sum(torch.square(mean - mean_other), dim=-1) - cov = torch.matmul(sqrt, sqrt.transpose(-1, -2)) - cov_other = torch.matmul(sqrt_other, sqrt_other.transpose(-1, -2)) - - if scale_prec: - identity = torch.eye(mean.shape[-1], dtype=sqrt.dtype, device=sqrt.device) - sqrt_inv_other = torch.linalg.solve(sqrt_other, identity) - c = sqrt_inv_other @ cov @ sqrt_inv_other - cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ sqrt) - else: - cov_part = torch.trace(cov_other + cov - 2 * sqrt_other @ sqrt) + if scale_or_sqrt.dim() == mean.dim(): # Diagonal case + cov = scale_or_sqrt.pow(2) + cov_other = scale_or_sqrt_other.pow(2) + if scale_prec: + identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device) + sqrt_inv_other = 1 / scale_or_sqrt_other + c = sqrt_inv_other.pow(2) * cov + cov_part = torch.sum(identity + c - 2 * sqrt_inv_other * scale_or_sqrt, dim=-1) + else: + cov_part = torch.sum(cov_other + cov - 2 * scale_or_sqrt_other * scale_or_sqrt, dim=-1) + else: # Full covariance case + # Note: scale_or_sqrt is treated as the matrix square root, not Cholesky decomposition + cov = torch.matmul(scale_or_sqrt, scale_or_sqrt.transpose(-1, -2)) + cov_other = torch.matmul(scale_or_sqrt_other, scale_or_sqrt_other.transpose(-1, -2)) + if scale_prec: + identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device) + sqrt_inv_other = torch.linalg.solve(scale_or_sqrt_other, identity) + c = sqrt_inv_other @ cov @ sqrt_inv_other.transpose(-1, -2) + cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ scale_or_sqrt) + else: + cov_part = torch.trace(cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt) return mean_part, cov_part class WassersteinProjection(BaseProjection): - def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False): - super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound) + def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True): + super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std) self.scale_prec = scale_prec def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - mean, sqrt = policy_params["loc"], policy_params["scale_tril"] - old_mean, old_sqrt = old_policy_params["loc"], old_policy_params["scale_tril"] + mean = policy_params["loc"] + old_mean = old_policy_params["loc"] + scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]]) + old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params[self.in_keys[1]]) - mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (old_mean, old_sqrt), self.scale_prec) + mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (old_mean, old_scale_or_sqrt), self.scale_prec) proj_mean 
 class WassersteinProjection(BaseProjection):
-    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
-        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
+    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
+        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
         self.scale_prec = scale_prec
 
     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
-        old_mean, old_sqrt = old_policy_params["loc"], old_policy_params["scale_tril"]
+        mean = policy_params["loc"]
+        old_mean = old_policy_params["loc"]
+        scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
+        old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params[self.in_keys[1]])
 
-        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (old_mean, old_sqrt), self.scale_prec)
+        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (old_mean, old_scale_or_sqrt), self.scale_prec)
 
         proj_mean = self._mean_projection(mean, old_mean, mean_part)
-        proj_sqrt = self._cov_projection(sqrt, old_sqrt, cov_part)
+        proj_scale_or_sqrt = self._cov_projection(scale_or_sqrt, old_scale_or_sqrt, cov_part)
 
-        return {"loc": proj_mean, "scale_tril": proj_sqrt}
+        return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_sqrt}
 
     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
-        mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
-        proj_mean, proj_sqrt = proj_policy_params["loc"], proj_policy_params["scale_tril"]
-        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (proj_mean, proj_sqrt), self.scale_prec)
+        mean = policy_params["loc"]
+        proj_mean = proj_policy_params["loc"]
+        scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
+        proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params[self.out_keys[1]])
+        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (proj_mean, proj_scale_or_sqrt), self.scale_prec)
         w2 = mean_part + cov_part
         return w2.mean() * self.trust_region_coeff
 
     def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
         diff = mean - old_mean
-        norm = torch.norm(diff, dim=-1, keepdim=True)
+        # keep a trailing dim so both the comparison and the division broadcast over [..., k]
+        norm = torch.sqrt(mean_part).unsqueeze(-1)
         return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm, mean)
 
-    def _cov_projection(self, sqrt: torch.Tensor, old_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
-        diff = sqrt - old_sqrt
-        norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
-        return torch.where(norm > self.cov_bound, old_sqrt + diff * self.cov_bound / norm, sqrt)
\ No newline at end of file
+    def _cov_projection(self, scale_or_sqrt: torch.Tensor, old_scale_or_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
+        if scale_or_sqrt.dim() == old_scale_or_sqrt.dim() == 2:  # Diagonal case
+            diff = scale_or_sqrt - old_scale_or_sqrt
+            norm = torch.sqrt(cov_part).unsqueeze(-1)  # trailing dim for broadcasting, as in _mean_projection
+            return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm, scale_or_sqrt)
+        else:  # Full covariance case
+            diff = scale_or_sqrt - old_scale_or_sqrt
+            norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
+            return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm, scale_or_sqrt)
\ No newline at end of file
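End-to-end usage sketch for the diagonal path (assumes the broadcasting fixes in `_mean_projection`/`_cov_projection` above; a plain dict again stands in for a TensorDict):

```python
import torch
from fancy_rl.projections.wasserstein_projection import WassersteinProjection

proj = WassersteinProjection(in_keys=["loc", "scale", "old_loc", "old_scale"],
                             out_keys=["loc", "scale"],
                             mean_bound=0.1, cov_bound=0.1)
batch = {"loc": torch.randn(16, 4), "scale": torch.rand(16, 4) + 0.1,
         "old_loc": torch.randn(16, 4), "old_scale": torch.rand(16, 4) + 0.1}
out = proj(batch)  # clips both the mean step and the scale step to their bounds
print(out["loc"].shape, out["scale"].shape)  # torch.Size([16, 4]) torch.Size([16, 4])
```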