From 5fc4b30ea83937336385dfc4eac797fd28652e44 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 28 Aug 2024 11:31:42 +0200 Subject: [PATCH] New projection impls --- fancy_rl/projections/__init__.py | 26 ++- fancy_rl/projections/base_projection.py | 16 ++ fancy_rl/projections/frob_projection.py | 67 ++++++ fancy_rl/projections/identity_projection.py | 13 ++ fancy_rl/projections/kl_projection.py | 199 ++++++++++++++++++ .../projections/wasserstein_projection.py | 56 +++++ fancy_rl/projections_legacy/__init__.py | 6 + .../base_projection_layer.py | 0 .../frob_projection_layer.py | 0 .../identity_projection_layer.py | 0 .../kl_projection_layer.py | 0 .../w2_projection_layer.py | 0 fancy_rl/projections_new/__init__.py | 13 -- fancy_rl/projections_new/base_projection.py | 51 ----- fancy_rl/projections_new/frob_projection.py | 22 -- .../projections_new/identity_projection.py | 11 - fancy_rl/projections_new/kl_projection.py | 33 --- fancy_rl/projections_new/w2_projection.py | 73 ------- 18 files changed, 377 insertions(+), 209 deletions(-) create mode 100644 fancy_rl/projections/base_projection.py create mode 100644 fancy_rl/projections/frob_projection.py create mode 100644 fancy_rl/projections/identity_projection.py create mode 100644 fancy_rl/projections/kl_projection.py create mode 100644 fancy_rl/projections/wasserstein_projection.py create mode 100644 fancy_rl/projections_legacy/__init__.py rename fancy_rl/{projections => projections_legacy}/base_projection_layer.py (100%) rename fancy_rl/{projections => projections_legacy}/frob_projection_layer.py (100%) rename fancy_rl/{projections => projections_legacy}/identity_projection_layer.py (100%) rename fancy_rl/{projections => projections_legacy}/kl_projection_layer.py (100%) rename fancy_rl/{projections => projections_legacy}/w2_projection_layer.py (100%) delete mode 100644 fancy_rl/projections_new/__init__.py delete mode 100644 fancy_rl/projections_new/base_projection.py delete mode 100644 fancy_rl/projections_new/frob_projection.py delete mode 100644 fancy_rl/projections_new/identity_projection.py delete mode 100644 fancy_rl/projections_new/kl_projection.py delete mode 100644 fancy_rl/projections_new/w2_projection.py diff --git a/fancy_rl/projections/__init__.py b/fancy_rl/projections/__init__.py index dc28ef1..fd075ba 100644 --- a/fancy_rl/projections/__init__.py +++ b/fancy_rl/projections/__init__.py @@ -1,6 +1,20 @@ -try: - import cpp_projection -except ModuleNotFoundError: - from .base_projection_layer import ITPALExceptionLayer as KLProjectionLayer -else: - from .kl_projection_layer import KLProjectionLayer \ No newline at end of file +from .base_projection import BaseProjection +from .identity_projection import IdentityProjection +from .kl_projection import KLProjection +from .wasserstein_projection import WassersteinProjection +from .frobenius_projection import FrobeniusProjection + +def get_projection(projection_name: str): + projections = { + "identity_projection": IdentityProjection, + "kl_projection": KLProjection, + "wasserstein_projection": WassersteinProjection, + "frobenius_projection": FrobeniusProjection, + } + + projection = projections.get(projection_name.lower()) + if projection is None: + raise ValueError(f"Unknown projection: {projection_name}") + return projection + +__all__ = ["BaseProjection", "IdentityProjection", "KLProjection", "WassersteinProjection", "FrobeniusProjection", "get_projection"] \ No newline at end of file diff --git a/fancy_rl/projections/base_projection.py b/fancy_rl/projections/base_projection.py new file mode 100644 index 0000000..31f1fa0 --- /dev/null +++ b/fancy_rl/projections/base_projection.py @@ -0,0 +1,16 @@ +from abc import ABC, abstractmethod +import torch +from typing import Dict + +class BaseProjection(ABC, torch.nn.Module): + def __init__(self, in_keys: list[str], out_keys: list[str]): + super().__init__() + self.in_keys = in_keys + self.out_keys = out_keys + + @abstractmethod + def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + pass + + def forward(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return self.project(policy_params, old_policy_params) \ No newline at end of file diff --git a/fancy_rl/projections/frob_projection.py b/fancy_rl/projections/frob_projection.py new file mode 100644 index 0000000..9ad3480 --- /dev/null +++ b/fancy_rl/projections/frob_projection.py @@ -0,0 +1,67 @@ +import torch +from .base_projection import BaseProjection +from typing import Dict + +class FrobeniusProjection(BaseProjection): + def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False): + super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound) + self.scale_prec = scale_prec + + def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + mean, chol = policy_params["loc"], policy_params["scale_tril"] + old_mean, old_chol = old_policy_params["loc"], old_policy_params["scale_tril"] + + cov = torch.matmul(chol, chol.transpose(-1, -2)) + old_cov = torch.matmul(old_chol, old_chol.transpose(-1, -2)) + + mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov)) + + proj_mean = self._mean_projection(mean, old_mean, mean_part) + proj_cov = self._cov_projection(cov, old_cov, cov_part) + + proj_chol = torch.linalg.cholesky(proj_cov) + return {"loc": proj_mean, "scale_tril": proj_chol} + + def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor: + mean, chol = policy_params["loc"], policy_params["scale_tril"] + proj_mean, proj_chol = proj_policy_params["loc"], proj_policy_params["scale_tril"] + + cov = torch.matmul(chol, chol.transpose(-1, -2)) + proj_cov = torch.matmul(proj_chol, proj_chol.transpose(-1, -2)) + + mean_diff = torch.sum(torch.square(mean - proj_mean), dim=-1) + cov_diff = torch.sum(torch.square(cov - proj_cov), dim=(-2, -1)) + + return (mean_diff + cov_diff).mean() * self.trust_region_coeff + + def _gaussian_frobenius(self, p, q): + mean, cov = p + old_mean, old_cov = q + + if self.scale_prec: + prec_old = torch.inverse(old_cov) + mean_part = torch.sum(torch.matmul(mean - old_mean, prec_old) * (mean - old_mean), dim=-1) + cov_part = torch.sum(prec_old * cov, dim=(-2, -1)) - torch.logdet(torch.matmul(prec_old, cov)) - mean.shape[-1] + else: + mean_part = torch.sum(torch.square(mean - old_mean), dim=-1) + cov_part = torch.sum(torch.square(cov - old_cov), dim=(-2, -1)) + + return mean_part, cov_part + + def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor: + diff = mean - old_mean + norm = torch.sqrt(mean_part) + return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm.unsqueeze(-1), mean) + + def _cov_projection(self, cov: torch.Tensor, old_cov: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor: + batch_shape = cov.shape[:-2] + cov_mask = cov_part > self.cov_bound + + eta = torch.ones(batch_shape, dtype=cov.dtype, device=cov.device) + eta[cov_mask] = torch.sqrt(cov_part[cov_mask] / self.cov_bound) - 1. + eta = torch.max(-eta, eta) + + new_cov = (cov + torch.einsum('i,ijk->ijk', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None] + proj_cov = torch.where(cov_mask[..., None, None], new_cov, cov) + + return proj_cov \ No newline at end of file diff --git a/fancy_rl/projections/identity_projection.py b/fancy_rl/projections/identity_projection.py new file mode 100644 index 0000000..adb8af3 --- /dev/null +++ b/fancy_rl/projections/identity_projection.py @@ -0,0 +1,13 @@ +import torch +from .base_projection import BaseProjection +from typing import Dict + +class IdentityProjection(BaseProjection): + def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01): + super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound) + + def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return policy_params + + def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor: + return torch.tensor(0.0, device=next(iter(policy_params.values())).device) \ No newline at end of file diff --git a/fancy_rl/projections/kl_projection.py b/fancy_rl/projections/kl_projection.py new file mode 100644 index 0000000..4dddecd --- /dev/null +++ b/fancy_rl/projections/kl_projection.py @@ -0,0 +1,199 @@ +import torch +import cpp_projection +import numpy as np +from .base_projection import BaseProjection +from typing import Dict, Tuple, Any + +MAX_EVAL = 1000 + +def get_numpy(tensor): + return tensor.detach().cpu().numpy() + +class KLProjection(BaseProjection): + def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, is_diag: bool = True, contextual_std: bool = True): + super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound) + self.is_diag = is_diag + self.contextual_std = contextual_std + + def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + mean, std = policy_params["loc"], policy_params["scale_tril"] + old_mean, old_std = old_policy_params["loc"], old_policy_params["scale_tril"] + + mean_part, cov_part = self._gaussian_kl((mean, std), (old_mean, old_std)) + + if not self.contextual_std: + std = std[:1] + old_std = old_std[:1] + cov_part = cov_part[:1] + + proj_mean = self._mean_projection(mean, old_mean, mean_part) + proj_std = self._cov_projection(std, old_std, cov_part) + + if not self.contextual_std: + proj_std = proj_std.expand(mean.shape[0], -1, -1) + + return {"loc": proj_mean, "scale_tril": proj_std} + + def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor: + mean, std = policy_params["loc"], policy_params["scale_tril"] + proj_mean, proj_std = proj_policy_params["loc"], proj_policy_params["scale_tril"] + kl = sum(self._gaussian_kl((mean, std), (proj_mean, proj_std))) + return kl.mean() * self.trust_region_coeff + + def _gaussian_kl(self, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + mean, std = p + mean_other, std_other = q + k = mean.shape[-1] + + maha_part = 0.5 * self._maha(mean, mean_other, std_other) + + det_term = self._log_determinant(std) + det_term_other = self._log_determinant(std_other) + + trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(std_other, std, upper=False)) + cov_part = 0.5 * (trace_part - k + det_term_other - det_term) + + return maha_part, cov_part + + def _maha(self, x: torch.Tensor, y: torch.Tensor, std: torch.Tensor) -> torch.Tensor: + diff = x - y + return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), std, upper=False)[0].squeeze(-1)), dim=-1) + + def _log_determinant(self, std: torch.Tensor) -> torch.Tensor: + return 2 * torch.log(std.diagonal(dim1=-2, dim2=-1)).sum(-1) + + def _torch_batched_trace_square(self, x: torch.Tensor) -> torch.Tensor: + return torch.sum(x.pow(2), dim=(-2, -1)) + + def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor: + return old_mean + (mean - old_mean) * torch.sqrt(self.mean_bound / (mean_part + 1e-8)).unsqueeze(-1) + + def _cov_projection(self, std: torch.Tensor, old_std: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor: + cov = torch.matmul(std, std.transpose(-1, -2)) + old_cov = torch.matmul(old_std, old_std.transpose(-1, -2)) + + if self.is_diag: + mask = cov_part > self.cov_bound + proj_std = torch.zeros_like(std) + proj_std[~mask] = std[~mask] + try: + if mask.any(): + proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov.diagonal(dim1=-2, dim2=-1), + old_cov.diagonal(dim1=-2, dim2=-1), + self.cov_bound) + is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask + if is_invalid.any(): + proj_std[is_invalid] = old_std[is_invalid] + mask &= ~is_invalid + proj_std[mask] = proj_cov[mask].sqrt().diag_embed() + except Exception as e: + proj_std = old_std + else: + try: + mask = cov_part > self.cov_bound + proj_std = torch.zeros_like(std) + proj_std[~mask] = std[~mask] + if mask.any(): + proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, std.detach(), old_std, self.cov_bound) + is_invalid = proj_cov.mean([-2, -1]).isnan() & mask + if is_invalid.any(): + proj_std[is_invalid] = old_std[is_invalid] + mask &= ~is_invalid + proj_std[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask]) + failed_mask = failed_mask.bool() + if failed_mask.any(): + proj_std[failed_mask] = old_std[failed_mask] + except Exception as e: + import logging + logging.error('Projection failed, taking old cholesky for projection.') + print("Projection failed, taking old cholesky for projection.") + proj_std = old_std + raise e + + return proj_std + + +class KLProjectionGradFunctionCovOnly(torch.autograd.Function): + projection_op = None + + @staticmethod + def get_projection_op(batch_shape, dim, max_eval=MAX_EVAL): + if not KLProjectionGradFunctionCovOnly.projection_op: + KLProjectionGradFunctionCovOnly.projection_op = \ + cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=max_eval) + return KLProjectionGradFunctionCovOnly.projection_op + + @staticmethod + def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: + cov, chol, old_chol, eps_cov = args + + batch_shape = cov.shape[0] + dim = cov.shape[-1] + + cov_np = get_numpy(cov) + chol_np = get_numpy(chol) + old_chol_np = get_numpy(old_chol) + eps = get_numpy(eps_cov) * np.ones(batch_shape) + + p_op = KLProjectionGradFunctionCovOnly.get_projection_op(batch_shape, dim) + ctx.proj = p_op + + proj_std = p_op.forward(eps, old_chol_np, chol_np, cov_np) + + return cov.new(proj_std) + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any) -> Any: + projection_op = ctx.proj + d_cov, = grad_outputs + + d_cov_np = get_numpy(d_cov) + d_cov_np = np.atleast_2d(d_cov_np) + + df_stds = projection_op.backward(d_cov_np) + df_stds = np.atleast_2d(df_stds) + + df_stds = d_cov.new(df_stds) + + return df_stds, None, None, None + + +class KLProjectionGradFunctionDiagCovOnly(torch.autograd.Function): + projection_op = None + + @staticmethod + def get_projection_op(batch_shape, dim, max_eval=MAX_EVAL): + if not KLProjectionGradFunctionDiagCovOnly.projection_op: + KLProjectionGradFunctionDiagCovOnly.projection_op = \ + cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=max_eval) + return KLProjectionGradFunctionDiagCovOnly.projection_op + + @staticmethod + def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: + cov, old_cov, eps_cov = args + + batch_shape = cov.shape[0] + dim = cov.shape[-1] + + cov_np = get_numpy(cov) + old_cov_np = get_numpy(old_cov) + eps = get_numpy(eps_cov) * np.ones(batch_shape) + + p_op = KLProjectionGradFunctionDiagCovOnly.get_projection_op(batch_shape, dim) + ctx.proj = p_op + + proj_std = p_op.forward(eps, old_cov_np, cov_np) + + return cov.new(proj_std) + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any) -> Any: + projection_op = ctx.proj + d_std, = grad_outputs + + d_cov_np = get_numpy(d_std) + d_cov_np = np.atleast_2d(d_cov_np) + df_stds = projection_op.backward(d_cov_np) + df_stds = np.atleast_2d(df_stds) + + return d_std.new(df_stds), None, None \ No newline at end of file diff --git a/fancy_rl/projections/wasserstein_projection.py b/fancy_rl/projections/wasserstein_projection.py new file mode 100644 index 0000000..fb58058 --- /dev/null +++ b/fancy_rl/projections/wasserstein_projection.py @@ -0,0 +1,56 @@ +import torch +from .base_projection import BaseProjection +from typing import Dict, Tuple + +def gaussian_wasserstein_commutative(policy, p: Tuple[torch.Tensor, torch.Tensor], + q: Tuple[torch.Tensor, torch.Tensor], scale_prec=False) -> Tuple[torch.Tensor, torch.Tensor]: + mean, sqrt = p + mean_other, sqrt_other = q + + mean_part = torch.sum(torch.square(mean - mean_other), dim=-1) + + cov = torch.matmul(sqrt, sqrt.transpose(-1, -2)) + cov_other = torch.matmul(sqrt_other, sqrt_other.transpose(-1, -2)) + + if scale_prec: + identity = torch.eye(mean.shape[-1], dtype=sqrt.dtype, device=sqrt.device) + sqrt_inv_other = torch.linalg.solve(sqrt_other, identity) + c = sqrt_inv_other @ cov @ sqrt_inv_other + cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ sqrt) + else: + cov_part = torch.trace(cov_other + cov - 2 * sqrt_other @ sqrt) + + return mean_part, cov_part + +class WassersteinProjection(BaseProjection): + def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False): + super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound) + self.scale_prec = scale_prec + + def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + mean, sqrt = policy_params["loc"], policy_params["scale_tril"] + old_mean, old_sqrt = old_policy_params["loc"], old_policy_params["scale_tril"] + + mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (old_mean, old_sqrt), self.scale_prec) + + proj_mean = self._mean_projection(mean, old_mean, mean_part) + proj_sqrt = self._cov_projection(sqrt, old_sqrt, cov_part) + + return {"loc": proj_mean, "scale_tril": proj_sqrt} + + def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor: + mean, sqrt = policy_params["loc"], policy_params["scale_tril"] + proj_mean, proj_sqrt = proj_policy_params["loc"], proj_policy_params["scale_tril"] + mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (proj_mean, proj_sqrt), self.scale_prec) + w2 = mean_part + cov_part + return w2.mean() * self.trust_region_coeff + + def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor: + diff = mean - old_mean + norm = torch.norm(diff, dim=-1, keepdim=True) + return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm, mean) + + def _cov_projection(self, sqrt: torch.Tensor, old_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor: + diff = sqrt - old_sqrt + norm = torch.norm(diff, dim=(-2, -1), keepdim=True) + return torch.where(norm > self.cov_bound, old_sqrt + diff * self.cov_bound / norm, sqrt) \ No newline at end of file diff --git a/fancy_rl/projections_legacy/__init__.py b/fancy_rl/projections_legacy/__init__.py new file mode 100644 index 0000000..dc28ef1 --- /dev/null +++ b/fancy_rl/projections_legacy/__init__.py @@ -0,0 +1,6 @@ +try: + import cpp_projection +except ModuleNotFoundError: + from .base_projection_layer import ITPALExceptionLayer as KLProjectionLayer +else: + from .kl_projection_layer import KLProjectionLayer \ No newline at end of file diff --git a/fancy_rl/projections/base_projection_layer.py b/fancy_rl/projections_legacy/base_projection_layer.py similarity index 100% rename from fancy_rl/projections/base_projection_layer.py rename to fancy_rl/projections_legacy/base_projection_layer.py diff --git a/fancy_rl/projections/frob_projection_layer.py b/fancy_rl/projections_legacy/frob_projection_layer.py similarity index 100% rename from fancy_rl/projections/frob_projection_layer.py rename to fancy_rl/projections_legacy/frob_projection_layer.py diff --git a/fancy_rl/projections/identity_projection_layer.py b/fancy_rl/projections_legacy/identity_projection_layer.py similarity index 100% rename from fancy_rl/projections/identity_projection_layer.py rename to fancy_rl/projections_legacy/identity_projection_layer.py diff --git a/fancy_rl/projections/kl_projection_layer.py b/fancy_rl/projections_legacy/kl_projection_layer.py similarity index 100% rename from fancy_rl/projections/kl_projection_layer.py rename to fancy_rl/projections_legacy/kl_projection_layer.py diff --git a/fancy_rl/projections/w2_projection_layer.py b/fancy_rl/projections_legacy/w2_projection_layer.py similarity index 100% rename from fancy_rl/projections/w2_projection_layer.py rename to fancy_rl/projections_legacy/w2_projection_layer.py diff --git a/fancy_rl/projections_new/__init__.py b/fancy_rl/projections_new/__init__.py deleted file mode 100644 index c95631d..0000000 --- a/fancy_rl/projections_new/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from .base_projection import BaseProjection -from .kl_projection import KLProjection -from .w2_projection import W2Projection -from .frob_projection import FrobProjection -from .identity_projection import IdentityProjection - -__all__ = [ - "BaseProjection", - "KLProjection", - "W2Projection", - "FrobProjection", - "IdentityProjection" -] \ No newline at end of file diff --git a/fancy_rl/projections_new/base_projection.py b/fancy_rl/projections_new/base_projection.py deleted file mode 100644 index 2fe512a..0000000 --- a/fancy_rl/projections_new/base_projection.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch -from torchrl.modules import TensorDictModule -from typing import List, Dict, Any - -class BaseProjection(TensorDictModule): - def __init__( - self, - in_keys: List[str], - out_keys: List[str], - ): - super().__init__(in_keys=in_keys, out_keys=out_keys) - - def forward(self, tensordict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - mean, std = self.in_keys - projected_mean, projected_std = self.out_keys - - old_mean = tensordict[mean] - old_std = tensordict[std] - - new_mean = tensordict.get(projected_mean, old_mean) - new_std = tensordict.get(projected_std, old_std) - - projected_params = self.project( - {"mean": new_mean, "std": new_std}, - {"mean": old_mean, "std": old_std} - ) - - tensordict[projected_mean] = projected_params["mean"] - tensordict[projected_std] = projected_params["std"] - - return tensordict - - def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - raise NotImplementedError("Subclasses must implement the project method") - - @classmethod - def make(cls, projection_type: str, **kwargs: Any) -> 'BaseProjection': - if projection_type == "kl": - from .kl_projection import KLProjection - return KLProjection(**kwargs) - elif projection_type == "w2": - from .w2_projection import W2Projection - return W2Projection(**kwargs) - elif projection_type == "frob": - from .frob_projection import FrobProjection - return FrobProjection(**kwargs) - elif projection_type == "identity": - from .identity_projection import IdentityProjection - return IdentityProjection(**kwargs) - else: - raise ValueError(f"Unknown projection type: {projection_type}") \ No newline at end of file diff --git a/fancy_rl/projections_new/frob_projection.py b/fancy_rl/projections_new/frob_projection.py deleted file mode 100644 index ff21293..0000000 --- a/fancy_rl/projections_new/frob_projection.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch -from .base_projection import BaseProjection -from typing import Dict - -class FrobProjection(BaseProjection): - def __init__(self, in_keys: list[str], out_keys: list[str], epsilon: float = 1e-3): - super().__init__(in_keys=in_keys, out_keys=out_keys) - self.epsilon = epsilon - - def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - projected_params = {} - for key in policy_params.keys(): - old_param = old_policy_params[key] - new_param = policy_params[key] - diff = new_param - old_param - norm = torch.norm(diff) - if norm > self.epsilon: - projected_param = old_param + (self.epsilon / norm) * diff - else: - projected_param = new_param - projected_params[key] = projected_param - return projected_params \ No newline at end of file diff --git a/fancy_rl/projections_new/identity_projection.py b/fancy_rl/projections_new/identity_projection.py deleted file mode 100644 index daad770..0000000 --- a/fancy_rl/projections_new/identity_projection.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch -from .base_projection import BaseProjection -from typing import Dict - -class IdentityProjection(BaseProjection): - def __init__(self, in_keys: list[str], out_keys: list[str]): - super().__init__(in_keys=in_keys, out_keys=out_keys) - - def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - # The identity projection simply returns the new policy parameters without any modification - return policy_params \ No newline at end of file diff --git a/fancy_rl/projections_new/kl_projection.py b/fancy_rl/projections_new/kl_projection.py deleted file mode 100644 index 33724e5..0000000 --- a/fancy_rl/projections_new/kl_projection.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch -from typing import Dict, List -from .base_projection import BaseProjection - -class KLProjection(BaseProjection): - def __init__( - self, - in_keys: List[str] = ["mean", "std"], - out_keys: List[str] = ["projected_mean", "projected_std"], - epsilon: float = 0.1 - ): - super().__init__(in_keys=in_keys, out_keys=out_keys) - self.epsilon = epsilon - - def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - new_mean, new_std = policy_params["mean"], policy_params["std"] - old_mean, old_std = old_policy_params["mean"], old_policy_params["std"] - - diff = new_mean - old_mean - std_diff = new_std - old_std - - kl = 0.5 * (torch.sum(torch.square(diff / old_std), dim=-1) + - torch.sum(torch.square(std_diff / old_std), dim=-1) - - new_mean.shape[-1] + - torch.sum(torch.log(new_std / old_std), dim=-1)) - - factor = torch.sqrt(self.epsilon / (kl + 1e-8)) - factor = torch.clamp(factor, max=1.0) - - projected_mean = old_mean + factor.unsqueeze(-1) * diff - projected_std = old_std + factor.unsqueeze(-1) * std_diff - - return {"mean": projected_mean, "std": projected_std} \ No newline at end of file diff --git a/fancy_rl/projections_new/w2_projection.py b/fancy_rl/projections_new/w2_projection.py deleted file mode 100644 index 7b9977c..0000000 --- a/fancy_rl/projections_new/w2_projection.py +++ /dev/null @@ -1,73 +0,0 @@ -import torch -from .base_projection import BaseProjection -from typing import Dict, Tuple -from torchrl.modules import TensorDictModule -from torchrl.distributions import TanhNormal, Delta - -class W2Projection(BaseProjection): - def __init__(self, - in_keys: list[str], - out_keys: list[str], - scale_prec: bool = False): - super().__init__(in_keys=in_keys, out_keys=out_keys) - self.scale_prec = scale_prec - - def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - projected_params = {} - for key in policy_params.keys(): - if key.endswith('.loc'): - mean = policy_params[key] - old_mean = old_policy_params[key] - std_key = key.replace('.loc', '.scale') - std = policy_params[std_key] - old_std = old_policy_params[std_key] - - projected_mean, projected_std = self._trust_region_projection( - mean, std, old_mean, old_std - ) - - projected_params[key] = projected_mean - projected_params[std_key] = projected_std - elif not key.endswith('.scale'): - projected_params[key] = policy_params[key] - - return projected_params - - def _trust_region_projection(self, mean: torch.Tensor, std: torch.Tensor, - old_mean: torch.Tensor, old_std: torch.Tensor, - eps: float = 1e-3, eps_cov: float = 1e-3) -> Tuple[torch.Tensor, torch.Tensor]: - mean_part, cov_part = self._gaussian_wasserstein_commutative(mean, std, old_mean, old_std) - - # Project mean - mean_mask = mean_part > eps - proj_mean = torch.where(mean_mask, - old_mean + (mean - old_mean) * torch.sqrt(eps / mean_part)[..., None], - mean) - - # Project covariance - cov_mask = cov_part > eps_cov - eta = torch.ones_like(cov_part) - eta[cov_mask] = torch.sqrt(cov_part[cov_mask] / eps_cov) - 1. - eta = torch.clamp(eta, -0.9, float('inf')) # Avoid negative values that could lead to invalid standard deviations - - proj_std = (std + eta[..., None] * old_std) / (1. + eta[..., None] + 1e-8) - - return proj_mean, proj_std - - def _gaussian_wasserstein_commutative(self, mean: torch.Tensor, std: torch.Tensor, - old_mean: torch.Tensor, old_std: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - if self.scale_prec: - # Mahalanobis distance for mean - mean_part = ((mean - old_mean) ** 2 / (old_std ** 2 + 1e-8)).sum(-1) - else: - # Euclidean distance for mean - mean_part = ((mean - old_mean) ** 2).sum(-1) - - # W2 objective for covariance - cov_part = (std ** 2 + old_std ** 2 - 2 * std * old_std).sum(-1) - - return mean_part, cov_part - - @classmethod - def make(cls, in_keys: list[str], out_keys: list[str], **kwargs) -> 'W2Projection': - return cls(in_keys=in_keys, out_keys=out_keys, **kwargs) \ No newline at end of file