From 5f279beccf2317c3629c7b97193a6f3cb25a91b9 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Wed, 17 Jul 2024 14:51:59 +0200
Subject: [PATCH] New proj implementations

---
 fancy_rl/projections_new/__init__.py        | 13 ++++
 fancy_rl/projections_new/base_projection.py | 51 +++++++++++++
 fancy_rl/projections_new/frob_projection.py | 22 ++++++
 .../projections_new/identity_projection.py  | 11 +++
 fancy_rl/projections_new/kl_projection.py   | 33 +++++++++
 fancy_rl/projections_new/w2_projection.py   | 73 +++++++++++++++++++
 6 files changed, 203 insertions(+)
 create mode 100644 fancy_rl/projections_new/__init__.py
 create mode 100644 fancy_rl/projections_new/base_projection.py
 create mode 100644 fancy_rl/projections_new/frob_projection.py
 create mode 100644 fancy_rl/projections_new/identity_projection.py
 create mode 100644 fancy_rl/projections_new/kl_projection.py
 create mode 100644 fancy_rl/projections_new/w2_projection.py

diff --git a/fancy_rl/projections_new/__init__.py b/fancy_rl/projections_new/__init__.py
new file mode 100644
index 0000000..c95631d
--- /dev/null
+++ b/fancy_rl/projections_new/__init__.py
@@ -0,0 +1,13 @@
+from .base_projection import BaseProjection
+from .kl_projection import KLProjection
+from .w2_projection import W2Projection
+from .frob_projection import FrobProjection
+from .identity_projection import IdentityProjection
+
+__all__ = [
+    "BaseProjection",
+    "KLProjection",
+    "W2Projection",
+    "FrobProjection",
+    "IdentityProjection"
+]
\ No newline at end of file
diff --git a/fancy_rl/projections_new/base_projection.py b/fancy_rl/projections_new/base_projection.py
new file mode 100644
index 0000000..2fe512a
--- /dev/null
+++ b/fancy_rl/projections_new/base_projection.py
@@ -0,0 +1,51 @@
+import torch
+from tensordict.nn import TensorDictModule
+from typing import List, Dict, Any
+
+class BaseProjection(TensorDictModule):
+    def __init__(
+        self,
+        in_keys: List[str],
+        out_keys: List[str],
+    ):
+        super().__init__(module=torch.nn.Identity(), in_keys=in_keys, out_keys=out_keys)  # placeholder module; forward() is overridden below
+
+    def forward(self, tensordict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        mean, std = self.in_keys  # expects exactly two in_keys: (mean, std)
+        projected_mean, projected_std = self.out_keys
+
+        old_mean = tensordict[mean]
+        old_std = tensordict[std]
+
+        new_mean = tensordict.get(projected_mean, old_mean)  # candidate params may already sit under the out_keys
+        new_std = tensordict.get(projected_std, old_std)
+
+        projected_params = self.project(
+            {"mean": new_mean, "std": new_std},
+            {"mean": old_mean, "std": old_std}
+        )
+
+        tensordict[projected_mean] = projected_params["mean"]
+        tensordict[projected_std] = projected_params["std"]
+
+        return tensordict
+
+    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        raise NotImplementedError("Subclasses must implement the project method")
+
+    @classmethod
+    def make(cls, projection_type: str, **kwargs: Any) -> 'BaseProjection':
+        if projection_type == "kl":
+            from .kl_projection import KLProjection
+            return KLProjection(**kwargs)
+        elif projection_type == "w2":
+            from .w2_projection import W2Projection
+            return W2Projection(**kwargs)
+        elif projection_type == "frob":
+            from .frob_projection import FrobProjection
+            return FrobProjection(**kwargs)
+        elif projection_type == "identity":
+            from .identity_projection import IdentityProjection
+            return IdentityProjection(**kwargs)
+        else:
+            raise ValueError(f"Unknown projection type: {projection_type}")
\ No newline at end of file
diff --git a/fancy_rl/projections_new/frob_projection.py b/fancy_rl/projections_new/frob_projection.py
new file mode 100644
index 0000000..ff21293
--- /dev/null
+++ b/fancy_rl/projections_new/frob_projection.py
@@ -0,0 +1,22 @@
+import torch
+from .base_projection import BaseProjection
+from typing import Dict
+
+class FrobProjection(BaseProjection):
+    def __init__(self, in_keys: list[str], out_keys: list[str], epsilon: float = 1e-3):
+        super().__init__(in_keys=in_keys, out_keys=out_keys)
+        self.epsilon = epsilon
+
+    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        projected_params = {}
+        for key in policy_params.keys():
+            old_param = old_policy_params[key]
+            new_param = policy_params[key]
+            diff = new_param - old_param
+            norm = torch.norm(diff)  # Frobenius norm over the full tensor (global, not per-sample)
+            if norm > self.epsilon:
+                projected_param = old_param + (self.epsilon / norm) * diff
+            else:
+                projected_param = new_param
+            projected_params[key] = projected_param
+        return projected_params
\ No newline at end of file
diff --git a/fancy_rl/projections_new/identity_projection.py b/fancy_rl/projections_new/identity_projection.py
new file mode 100644
index 0000000..daad770
--- /dev/null
+++ b/fancy_rl/projections_new/identity_projection.py
@@ -0,0 +1,11 @@
+import torch
+from .base_projection import BaseProjection
+from typing import Dict
+
+class IdentityProjection(BaseProjection):
+    def __init__(self, in_keys: list[str], out_keys: list[str]):
+        super().__init__(in_keys=in_keys, out_keys=out_keys)
+
+    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        # The identity projection returns the new policy parameters unmodified
+        return policy_params
\ No newline at end of file
diff --git a/fancy_rl/projections_new/kl_projection.py b/fancy_rl/projections_new/kl_projection.py
new file mode 100644
index 0000000..33724e5
--- /dev/null
+++ b/fancy_rl/projections_new/kl_projection.py
@@ -0,0 +1,33 @@
+import torch
+from typing import Dict, List
+from .base_projection import BaseProjection
+
+class KLProjection(BaseProjection):
+    def __init__(
+        self,
+        in_keys: List[str] = ["mean", "std"],
+        out_keys: List[str] = ["projected_mean", "projected_std"],
+        epsilon: float = 0.1
+    ):
+        super().__init__(in_keys=in_keys, out_keys=out_keys)
+        self.epsilon = epsilon
+
+    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        new_mean, new_std = policy_params["mean"], policy_params["std"]
+        old_mean, old_std = old_policy_params["mean"], old_policy_params["std"]
+
+        diff = new_mean - old_mean
+        std_diff = new_std - old_std
+
+        kl = 0.5 * (torch.sum(torch.square(diff / old_std), dim=-1) +  # KL(new || old) for diagonal Gaussians
+                    torch.sum(torch.square(new_std / old_std), dim=-1) -
+                    new_mean.shape[-1] -
+                    2.0 * torch.sum(torch.log(new_std / old_std), dim=-1))
+
+        factor = torch.sqrt(self.epsilon / (kl + 1e-8))  # KL is ~quadratic in the diff, so one scale factor ~enforces the bound
+        factor = torch.clamp(factor, max=1.0)
+
+        projected_mean = old_mean + factor.unsqueeze(-1) * diff
+        projected_std = old_std + factor.unsqueeze(-1) * std_diff
+
+        return {"mean": projected_mean, "std": projected_std}
\ No newline at end of file
diff --git a/fancy_rl/projections_new/w2_projection.py b/fancy_rl/projections_new/w2_projection.py
new file mode 100644
index 0000000..7b9977c
--- /dev/null
+++ b/fancy_rl/projections_new/w2_projection.py
@@ -0,0 +1,73 @@
+import torch
+from .base_projection import BaseProjection
+from typing import Dict, Tuple
+
+class W2Projection(BaseProjection):
+    def __init__(self,
+                 in_keys: list[str],
+                 out_keys: list[str],
+                 scale_prec: bool = False):
+        super().__init__(in_keys=in_keys, out_keys=out_keys)
+        self.scale_prec = scale_prec
+
+    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        projected_params = {}
+        for key in policy_params.keys():
+            # Distribution parameters are expected under '<prefix>.loc' / '<prefix>.scale' keys
+            if key.endswith('.loc'):
+                mean = policy_params[key]
+                old_mean = old_policy_params[key]
+                std_key = key.replace('.loc', '.scale')
+                std = policy_params[std_key]
+                old_std = old_policy_params[std_key]
+
+                projected_mean, projected_std = self._trust_region_projection(
+                    mean, std, old_mean, old_std
+                )
+
+                projected_params[key] = projected_mean
+                projected_params[std_key] = projected_std
+            elif not key.endswith('.scale'):
+                projected_params[key] = policy_params[key]
+
+        return projected_params
+
+    def _trust_region_projection(self, mean: torch.Tensor, std: torch.Tensor,
+                                 old_mean: torch.Tensor, old_std: torch.Tensor,
+                                 eps: float = 1e-3, eps_cov: float = 1e-3) -> Tuple[torch.Tensor, torch.Tensor]:
+        mean_part, cov_part = self._gaussian_wasserstein_commutative(mean, std, old_mean, old_std)
+
+        # Project mean
+        mean_mask = mean_part > eps
+        proj_mean = torch.where(mean_mask,
+                                old_mean + (mean - old_mean) * torch.sqrt(eps / mean_part)[..., None],
+                                mean)
+
+        # Project covariance
+        cov_mask = cov_part > eps_cov
+        eta = torch.ones_like(cov_part)
+        eta[cov_mask] = torch.sqrt(cov_part[cov_mask] / eps_cov) - 1.
+        eta = torch.clamp(eta, -0.9, float('inf'))  # Avoid negative values that could lead to invalid standard deviations
+
+        # Blend new and old std; eta grows with the size of the covariance violation
+        proj_std = (std + eta[..., None] * old_std) / (1. + eta[..., None] + 1e-8)
+
+        return proj_mean, proj_std
+
+    def _gaussian_wasserstein_commutative(self, mean: torch.Tensor, std: torch.Tensor,
+                                          old_mean: torch.Tensor, old_std: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.scale_prec:
+            # Mahalanobis distance for mean
+            mean_part = ((mean - old_mean) ** 2 / (old_std ** 2 + 1e-8)).sum(-1)
+        else:
+            # Euclidean distance for mean
+            mean_part = ((mean - old_mean) ** 2).sum(-1)
+
+        # W2 objective for covariance: equals ((std - old_std) ** 2).sum(-1) for commuting diagonal covariances
+        cov_part = (std ** 2 + old_std ** 2 - 2 * std * old_std).sum(-1)
+
+        return mean_part, cov_part
+
+    @classmethod
+    def make(cls, in_keys: list[str], out_keys: list[str], **kwargs) -> 'W2Projection':
+        return cls(in_keys=in_keys, out_keys=out_keys, **kwargs)
\ No newline at end of file
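
Reviewer note, not part of the patch: a minimal usage sketch of the new API, assuming the patch is applied and the `tensordict` dependency is installed. The key names and tensor shapes are illustrative assumptions. It builds a KLProjection through the BaseProjection.make factory and runs it on a TensorDict that holds the old parameters under the in_keys and the candidate parameters under the out_keys, which forward() then overwrites with the projected values.

# Hypothetical usage sketch (not part of the patch)
import torch
from tensordict import TensorDict
from fancy_rl.projections_new import BaseProjection

proj = BaseProjection.make(
    "kl",
    in_keys=["mean", "std"],                       # old policy parameters
    out_keys=["projected_mean", "projected_std"],  # candidate / projected parameters
    epsilon=0.1,
)

td = TensorDict({
    "mean": torch.zeros(4, 2),
    "std": torch.ones(4, 2),
    "projected_mean": 0.5 * torch.ones(4, 2),  # candidate new mean
    "projected_std": 1.2 * torch.ones(4, 2),   # candidate new std
}, batch_size=[4])

td = proj(td)  # forward() overwrites the out_keys in place
print(td["projected_mean"], td["projected_std"])

The same call should work with "frob" and "identity" by swapping the first argument. Note that W2Projection.project keys on '.loc'/'.scale' suffixes, which does not match the "mean"/"std" dicts that BaseProjection.forward constructs, so the sketch above only covers the KL path.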