New projection implementations
parent 4240f611ac
commit 5f279beccf
13  fancy_rl/projections_new/__init__.py  Normal file
@@ -0,0 +1,13 @@
from .base_projection import BaseProjection
from .kl_projection import KLProjection
from .w2_projection import W2Projection
from .frob_projection import FrobProjection
from .identity_projection import IdentityProjection

__all__ = [
    "BaseProjection",
    "KLProjection",
    "W2Projection",
    "FrobProjection",
    "IdentityProjection",
]
51  fancy_rl/projections_new/base_projection.py  Normal file
@@ -0,0 +1,51 @@
import torch
from tensordict.nn import TensorDictModuleBase
from typing import List, Dict, Any


class BaseProjection(TensorDictModuleBase):
    # TensorDictModuleBase (from the tensordict package) lets a subclass declare
    # its own in_keys/out_keys; TensorDictModule itself requires wrapping an
    # inner module, which a projection does not have.
    def __init__(
        self,
        in_keys: List[str],
        out_keys: List[str],
    ):
        super().__init__()
        self.in_keys = in_keys
        self.out_keys = out_keys

    def forward(self, tensordict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        mean_key, std_key = self.in_keys
        projected_mean_key, projected_std_key = self.out_keys

        old_mean = tensordict[mean_key]
        old_std = tensordict[std_key]

        # Fall back to the old parameters when no candidate update is present yet.
        new_mean = tensordict.get(projected_mean_key, old_mean)
        new_std = tensordict.get(projected_std_key, old_std)

        projected_params = self.project(
            {"mean": new_mean, "std": new_std},
            {"mean": old_mean, "std": old_std},
        )

        tensordict[projected_mean_key] = projected_params["mean"]
        tensordict[projected_std_key] = projected_params["std"]

        return tensordict

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        raise NotImplementedError("Subclasses must implement the project method")

    @classmethod
    def make(cls, projection_type: str, **kwargs: Any) -> 'BaseProjection':
        # Factory dispatch on the projection name; imports stay local to avoid
        # circular imports with the submodules.
        if projection_type == "kl":
            from .kl_projection import KLProjection
            return KLProjection(**kwargs)
        elif projection_type == "w2":
            from .w2_projection import W2Projection
            return W2Projection(**kwargs)
        elif projection_type == "frob":
            from .frob_projection import FrobProjection
            return FrobProjection(**kwargs)
        elif projection_type == "identity":
            from .identity_projection import IdentityProjection
            return IdentityProjection(**kwargs)
        else:
            raise ValueError(f"Unknown projection type: {projection_type}")
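For reference, a minimal usage sketch of the factory (the key names are illustrative, and a plain dict stands in for a TensorDict):

import torch
from fancy_rl.projections_new import BaseProjection

# Hypothetical usage: build a KL projection via the factory.
proj = BaseProjection.make(
    "kl",
    in_keys=["mean", "std"],
    out_keys=["projected_mean", "projected_std"],
)
params = {"mean": torch.zeros(4), "std": torch.ones(4)}
params = proj(params)
# No candidate update is present under the out_keys, so forward falls back
# to the old parameters and the projection leaves them unchanged.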
22  fancy_rl/projections_new/frob_projection.py  Normal file
@@ -0,0 +1,22 @@
import torch
from .base_projection import BaseProjection
from typing import Dict


class FrobProjection(BaseProjection):
    def __init__(self, in_keys: list[str], out_keys: list[str], epsilon: float = 1e-3):
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.epsilon = epsilon

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        projected_params = {}
        for key in policy_params.keys():
            old_param = old_policy_params[key]
            new_param = policy_params[key]
            diff = new_param - old_param
            # Frobenius norm of the update, taken over all dimensions at once
            # (batch dimensions included).
            norm = torch.norm(diff)
            if norm > self.epsilon:
                # Rescale the update so its norm equals epsilon: a projection
                # onto the epsilon-ball around the old parameters.
                projected_param = old_param + (self.epsilon / norm) * diff
            else:
                projected_param = new_param
            projected_params[key] = projected_param
        return projected_params
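A quick numeric sketch of the clipping behaviour (hypothetical key names and values):

import torch
from fancy_rl.projections_new import FrobProjection

proj = FrobProjection(in_keys=["mean", "std"],
                      out_keys=["projected_mean", "projected_std"],
                      epsilon=0.5)
old = {"mean": torch.zeros(3), "std": torch.ones(3)}
new = {"mean": torch.tensor([3.0, 4.0, 0.0]), "std": torch.ones(3)}
out = proj.project(new, old)
print(torch.norm(out["mean"]))  # tensor(0.5000): the norm-5 update was rescaled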
11  fancy_rl/projections_new/identity_projection.py  Normal file
@@ -0,0 +1,11 @@
import torch
from .base_projection import BaseProjection
from typing import Dict


class IdentityProjection(BaseProjection):
    def __init__(self, in_keys: list[str], out_keys: list[str]):
        super().__init__(in_keys=in_keys, out_keys=out_keys)

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # The identity projection simply returns the new policy parameters without any modification
        return policy_params
33  fancy_rl/projections_new/kl_projection.py  Normal file
@@ -0,0 +1,33 @@
import torch
from typing import Dict, List
from .base_projection import BaseProjection


class KLProjection(BaseProjection):
    def __init__(
        self,
        in_keys: List[str] = ["mean", "std"],
        out_keys: List[str] = ["projected_mean", "projected_std"],
        epsilon: float = 0.1
    ):
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.epsilon = epsilon

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        new_mean, new_std = policy_params["mean"], policy_params["std"]
        old_mean, old_std = old_policy_params["mean"], old_policy_params["std"]

        diff = new_mean - old_mean
        std_diff = new_std - old_std

        # KL(new || old) for diagonal Gaussians:
        # 0.5 * (sum(sigma_n^2 / sigma_o^2) + sum((mu_n - mu_o)^2 / sigma_o^2)
        #        - k + sum(log(sigma_o^2 / sigma_n^2)))
        kl = 0.5 * (torch.sum(torch.square(new_std / old_std), dim=-1) +
                    torch.sum(torch.square(diff / old_std), dim=-1) -
                    new_mean.shape[-1] +
                    2 * torch.sum(torch.log(old_std / new_std), dim=-1))

        # Heuristic rescaling: shrink the update toward the old parameters
        # until the KL fits inside the trust region of size epsilon.
        factor = torch.sqrt(self.epsilon / (kl + 1e-8))
        factor = torch.clamp(factor, max=1.0)

        projected_mean = old_mean + factor.unsqueeze(-1) * diff
        projected_std = old_std + factor.unsqueeze(-1) * std_diff

        return {"mean": projected_mean, "std": projected_std}
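A short sketch of the projection in action (hypothetical values):

import torch
from fancy_rl.projections_new import KLProjection

proj = KLProjection(epsilon=0.1)
old = {"mean": torch.zeros(2), "std": torch.ones(2)}
new = {"mean": torch.tensor([1.0, -1.0]), "std": torch.tensor([1.5, 0.5])}
out = proj.project(new, old)
# The projected mean/std lie on the segment between the old and new
# parameters, shrunk so the KL to the old Gaussian stays near epsilon.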
73  fancy_rl/projections_new/w2_projection.py  Normal file
@@ -0,0 +1,73 @@
import torch
from .base_projection import BaseProjection
from typing import Dict, Tuple


class W2Projection(BaseProjection):
    def __init__(self,
                 in_keys: list[str],
                 out_keys: list[str],
                 scale_prec: bool = False):
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.scale_prec = scale_prec

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        projected_params = {}
        for key in policy_params.keys():
            # Accept both the flat "mean"/"std" layout produced by
            # BaseProjection.forward and TensorDict-style
            # "<prefix>.loc"/"<prefix>.scale" key pairs.
            if key.endswith('.loc') or key == 'mean':
                std_key = key.replace('.loc', '.scale') if key.endswith('.loc') else 'std'
                mean = policy_params[key]
                old_mean = old_policy_params[key]
                std = policy_params[std_key]
                old_std = old_policy_params[std_key]

                projected_mean, projected_std = self._trust_region_projection(
                    mean, std, old_mean, old_std
                )

                projected_params[key] = projected_mean
                projected_params[std_key] = projected_std
            elif not (key.endswith('.scale') or key == 'std'):
                # Anything that is not part of a Gaussian pair passes through.
                projected_params[key] = policy_params[key]

        return projected_params

    def _trust_region_projection(self, mean: torch.Tensor, std: torch.Tensor,
                                 old_mean: torch.Tensor, old_std: torch.Tensor,
                                 eps: float = 1e-3, eps_cov: float = 1e-3) -> Tuple[torch.Tensor, torch.Tensor]:
        mean_part, cov_part = self._gaussian_wasserstein_commutative(mean, std, old_mean, old_std)

        # Project the mean: shrink the mean update wherever its W2 contribution exceeds eps.
        mean_mask = mean_part > eps
        proj_mean = torch.where(mean_mask,
                                old_mean + (mean - old_mean) * torch.sqrt(eps / (mean_part + 1e-8))[..., None],
                                mean)

        # Project the covariance: interpolate between the new and old std.
        cov_mask = cov_part > eps_cov
        eta = torch.ones_like(cov_part)
        eta[cov_mask] = torch.sqrt(cov_part[cov_mask] / eps_cov) - 1.
        eta = torch.clamp(eta, -0.9, float('inf'))  # Avoid negative values that could lead to invalid standard deviations

        proj_std = (std + eta[..., None] * old_std) / (1. + eta[..., None] + 1e-8)

        return proj_mean, proj_std

    def _gaussian_wasserstein_commutative(self, mean: torch.Tensor, std: torch.Tensor,
                                          old_mean: torch.Tensor, old_std: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.scale_prec:
            # Mahalanobis distance for the mean, scaled by the old precision
            mean_part = ((mean - old_mean) ** 2 / (old_std ** 2 + 1e-8)).sum(-1)
        else:
            # Euclidean distance for the mean
            mean_part = ((mean - old_mean) ** 2).sum(-1)

        # W2 objective for the (commuting, here diagonal) covariances
        cov_part = (std ** 2 + old_std ** 2 - 2 * std * old_std).sum(-1)

        return mean_part, cov_part

    @classmethod
    def make(cls, in_keys: list[str], out_keys: list[str], **kwargs) -> 'W2Projection':
        return cls(in_keys=in_keys, out_keys=out_keys, **kwargs)
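Finally, a batched usage sketch (hypothetical values; key names match the flat layout above):

import torch
from fancy_rl.projections_new import W2Projection

proj = W2Projection(in_keys=["mean", "std"],
                    out_keys=["projected_mean", "projected_std"])
old = {"mean": torch.zeros(2, 3), "std": torch.ones(2, 3)}
new = {"mean": 0.2 * torch.ones(2, 3), "std": 1.3 * torch.ones(2, 3)}
out = proj.project(new, old)
# Per batch element, the mean shift and std change are shrunk toward the
# old policy whenever their W2 parts exceed eps / eps_cov (1e-3 by default).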