New projection implementations
parent 4240f611ac
commit 5f279beccf

13 fancy_rl/projections_new/__init__.py Normal file
@@ -0,0 +1,13 @@
from .base_projection import BaseProjection
from .kl_projection import KLProjection
from .w2_projection import W2Projection
from .frob_projection import FrobProjection
from .identity_projection import IdentityProjection

__all__ = [
    "BaseProjection",
    "KLProjection",
    "W2Projection",
    "FrobProjection",
    "IdentityProjection"
]

51 fancy_rl/projections_new/base_projection.py Normal file
@@ -0,0 +1,51 @@
import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule  # TensorDictModule lives in tensordict.nn in recent releases
from typing import List, Dict, Any


class BaseProjection(TensorDictModule):
    def __init__(
        self,
        in_keys: List[str],
        out_keys: List[str],
    ):
        # TensorDictModule requires a wrapped module; forward() is overridden
        # below, so a placeholder Identity suffices here.
        super().__init__(module=torch.nn.Identity(), in_keys=in_keys, out_keys=out_keys)

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        mean, std = self.in_keys
        projected_mean, projected_std = self.out_keys

        # Reference (old) distribution parameters, read from in_keys.
        old_mean = tensordict[mean]
        old_std = tensordict[std]

        # Parameters to project: taken from out_keys if already present,
        # otherwise defaulting to the old parameters (identity projection).
        new_mean = tensordict.get(projected_mean, default=old_mean)
        new_std = tensordict.get(projected_std, default=old_std)

        projected_params = self.project(
            {"mean": new_mean, "std": new_std},
            {"mean": old_mean, "std": old_std}
        )

        tensordict[projected_mean] = projected_params["mean"]
        tensordict[projected_std] = projected_params["std"]

        return tensordict

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        raise NotImplementedError("Subclasses must implement the project method")

    @classmethod
    def make(cls, projection_type: str, **kwargs: Any) -> 'BaseProjection':
        if projection_type == "kl":
            from .kl_projection import KLProjection
            return KLProjection(**kwargs)
        elif projection_type == "w2":
            from .w2_projection import W2Projection
            return W2Projection(**kwargs)
        elif projection_type == "frob":
            from .frob_projection import FrobProjection
            return FrobProjection(**kwargs)
        elif projection_type == "identity":
            from .identity_projection import IdentityProjection
            return IdentityProjection(**kwargs)
        else:
            raise ValueError(f"Unknown projection type: {projection_type}")
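
Below is a minimal usage sketch (illustrative, not part of this commit): it assumes a TensorDict that already carries the reference parameters under the in_keys "mean" and "std" and wires a projection through the make factory above; key names and shapes are placeholders.

# Usage sketch (assumptions noted above).
import torch
from tensordict import TensorDict
from fancy_rl.projections_new import BaseProjection

proj = BaseProjection.make(
    "kl",
    in_keys=["mean", "std"],                       # reference (old) policy params
    out_keys=["projected_mean", "projected_std"],  # where projected params land
)
td = TensorDict(
    {"mean": torch.zeros(4, 2), "std": torch.ones(4, 2)},
    batch_size=[4],
)
td = proj(td)  # populates "projected_mean" / "projected_std"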

22 fancy_rl/projections_new/frob_projection.py Normal file
@@ -0,0 +1,22 @@
import torch
from .base_projection import BaseProjection
from typing import Dict, List


class FrobProjection(BaseProjection):
    def __init__(self, in_keys: List[str], out_keys: List[str], epsilon: float = 1e-3):
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.epsilon = epsilon

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        projected_params = {}
        for key in policy_params.keys():
            old_param = old_policy_params[key]
            new_param = policy_params[key]
            diff = new_param - old_param
            # Rescale the update so its Frobenius norm stays within epsilon.
            norm = torch.norm(diff)
            if norm > self.epsilon:
                projected_param = old_param + (self.epsilon / norm) * diff
            else:
                projected_param = new_param
            projected_params[key] = projected_param
        return projected_params
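
A quick illustrative check of the Frobenius trust region (not part of the commit; key names are placeholders): updates whose norm exceeds epsilon are rescaled onto the epsilon-ball around the old parameters.

import torch
from fancy_rl.projections_new import FrobProjection

proj = FrobProjection(in_keys=["mean", "std"],
                      out_keys=["projected_mean", "projected_std"],
                      epsilon=0.5)
old = {"mean": torch.zeros(3), "std": torch.ones(3)}
new = {"mean": torch.full((3,), 2.0), "std": torch.ones(3)}
out = proj.project(new, old)
print(torch.norm(out["mean"] - old["mean"]))  # ~0.5, clipped to epsilon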

11 fancy_rl/projections_new/identity_projection.py Normal file
@@ -0,0 +1,11 @@
import torch
from .base_projection import BaseProjection
from typing import Dict, List


class IdentityProjection(BaseProjection):
    def __init__(self, in_keys: List[str], out_keys: List[str]):
        super().__init__(in_keys=in_keys, out_keys=out_keys)

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # The identity projection returns the new policy parameters unchanged.
        return policy_params

33 fancy_rl/projections_new/kl_projection.py Normal file
@@ -0,0 +1,33 @@
import torch
from typing import Dict, List
from .base_projection import BaseProjection


class KLProjection(BaseProjection):
    def __init__(
        self,
        in_keys: List[str] = ["mean", "std"],
        out_keys: List[str] = ["projected_mean", "projected_std"],
        epsilon: float = 0.1
    ):
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.epsilon = epsilon

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        new_mean, new_std = policy_params["mean"], policy_params["std"]
        old_mean, old_std = old_policy_params["mean"], old_policy_params["std"]

        diff = new_mean - old_mean
        std_diff = new_std - old_std

        # Closed-form KL(new || old) between diagonal Gaussians:
        # 0.5 * sum( (mu_n - mu_o)^2 / s_o^2 + s_n^2 / s_o^2 - 1 - log(s_n^2 / s_o^2) )
        var_ratio = torch.square(new_std / old_std)
        kl = 0.5 * (torch.sum(torch.square(diff / old_std), dim=-1) +
                    torch.sum(var_ratio, dim=-1) -
                    new_mean.shape[-1] -
                    torch.sum(torch.log(var_ratio), dim=-1))

        # Approximate projection: rescale the update toward the old
        # parameters so the KL bound epsilon is (approximately) met.
        factor = torch.sqrt(self.epsilon / (kl + 1e-8))
        factor = torch.clamp(factor, max=1.0)

        projected_mean = old_mean + factor.unsqueeze(-1) * diff
        projected_std = old_std + factor.unsqueeze(-1) * std_diff

        return {"mean": projected_mean, "std": projected_std}

73 fancy_rl/projections_new/w2_projection.py Normal file
@@ -0,0 +1,73 @@
import torch
from .base_projection import BaseProjection
from typing import Dict, List, Tuple


class W2Projection(BaseProjection):
    def __init__(self,
                 in_keys: List[str],
                 out_keys: List[str],
                 scale_prec: bool = False):
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.scale_prec = scale_prec

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        projected_params = {}
        for key in policy_params.keys():
            # Gaussian parameters may arrive either as '<prefix>.loc'/'<prefix>.scale'
            # pairs or as the plain 'mean'/'std' keys used by BaseProjection.forward.
            if key.endswith('.loc') or key == 'mean':
                mean = policy_params[key]
                old_mean = old_policy_params[key]
                std_key = key.replace('.loc', '.scale') if key.endswith('.loc') else 'std'
                std = policy_params[std_key]
                old_std = old_policy_params[std_key]

                projected_mean, projected_std = self._trust_region_projection(
                    mean, std, old_mean, old_std
                )

                projected_params[key] = projected_mean
                projected_params[std_key] = projected_std
            elif not (key.endswith('.scale') or key == 'std'):
                # Non-Gaussian parameters pass through unchanged; scale entries
                # are handled together with their matching location entry.
                projected_params[key] = policy_params[key]

        return projected_params

    def _trust_region_projection(self, mean: torch.Tensor, std: torch.Tensor,
                                 old_mean: torch.Tensor, old_std: torch.Tensor,
                                 eps: float = 1e-3, eps_cov: float = 1e-3) -> Tuple[torch.Tensor, torch.Tensor]:
        mean_part, cov_part = self._gaussian_wasserstein_commutative(mean, std, old_mean, old_std)

        # Project mean
        mean_mask = mean_part > eps
        proj_mean = torch.where(mean_mask[..., None],
                                old_mean + (mean - old_mean) * torch.sqrt(eps / mean_part)[..., None],
                                mean)

        # Project covariance: eta = 0 leaves std unchanged where the
        # constraint is already satisfied.
        cov_mask = cov_part > eps_cov
        eta = torch.zeros_like(cov_part)
        eta[cov_mask] = torch.sqrt(cov_part[cov_mask] / eps_cov) - 1.
        eta = torch.clamp(eta, min=-0.9)  # safeguard: keep the denominator positive

        proj_std = (std + eta[..., None] * old_std) / (1. + eta[..., None] + 1e-8)

        return proj_mean, proj_std

    def _gaussian_wasserstein_commutative(self, mean: torch.Tensor, std: torch.Tensor,
                                          old_mean: torch.Tensor, old_std: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.scale_prec:
            # Mahalanobis distance for the mean part
            mean_part = ((mean - old_mean) ** 2 / (old_std ** 2 + 1e-8)).sum(-1)
        else:
            # Euclidean distance for the mean part
            mean_part = ((mean - old_mean) ** 2).sum(-1)

        # W2 objective between commuting (diagonal) covariances
        cov_part = (std ** 2 + old_std ** 2 - 2 * std * old_std).sum(-1)

        return mean_part, cov_part

    @classmethod
    def make(cls, in_keys: List[str], out_keys: List[str], **kwargs) -> 'W2Projection':
        return cls(in_keys=in_keys, out_keys=out_keys, **kwargs)
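
And an illustrative call of the W2 projection on plain "mean"/"std" parameters (not part of the commit; shapes and values are placeholders):

import torch
from fancy_rl.projections_new import W2Projection

proj = W2Projection(in_keys=["mean", "std"],
                    out_keys=["projected_mean", "projected_std"])
old = {"mean": torch.zeros(5, 2), "std": torch.ones(5, 2)}
new = {"mean": 0.5 * torch.ones(5, 2), "std": 2.0 * torch.ones(5, 2)}
out = proj.project(new, old)
print(out["mean"][0], out["std"][0])  # pulled back toward old within eps / eps_cov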