New proj implementations

Dominik Moritz Roth 2024-07-17 14:51:59 +02:00
parent 4240f611ac
commit 5f279beccf
6 changed files with 203 additions and 0 deletions

View File: __init__.py

@@ -0,0 +1,13 @@
from .base_projection import BaseProjection
from .kl_projection import KLProjection
from .w2_projection import W2Projection
from .frob_projection import FrobProjection
from .identity_projection import IdentityProjection

__all__ = [
    "BaseProjection",
    "KLProjection",
    "W2Projection",
    "FrobProjection",
    "IdentityProjection",
]

View File: base_projection.py

@@ -0,0 +1,51 @@
import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModuleBase
from typing import List, Dict, Any


# TensorDictModuleBase allows a custom forward without wrapping an inner
# nn.Module, which TensorDictModule would require as its first argument.
class BaseProjection(TensorDictModuleBase):
    def __init__(
        self,
        in_keys: List[str],
        out_keys: List[str],
    ):
        super().__init__()
        self.in_keys = in_keys
        self.out_keys = out_keys

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        # The in_keys carry the old (behavior) policy parameters. Candidate
        # new parameters are read from the out_keys if already present,
        # otherwise the old values are reused unchanged.
        mean, std = self.in_keys
        projected_mean, projected_std = self.out_keys
        old_mean = tensordict[mean]
        old_std = tensordict[std]
        new_mean = tensordict.get(projected_mean, old_mean)
        new_std = tensordict.get(projected_std, old_std)
        projected_params = self.project(
            {"mean": new_mean, "std": new_std},
            {"mean": old_mean, "std": old_std},
        )
        tensordict[projected_mean] = projected_params["mean"]
        tensordict[projected_std] = projected_params["std"]
        return tensordict

    def project(self, policy_params: Dict[str, torch.Tensor],
                old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        raise NotImplementedError("Subclasses must implement the project method")

    @classmethod
    def make(cls, projection_type: str, **kwargs: Any) -> 'BaseProjection':
        # Factory resolving a projection by name; the imports are kept local
        # to avoid circular imports between the submodules.
        if projection_type == "kl":
            from .kl_projection import KLProjection
            return KLProjection(**kwargs)
        elif projection_type == "w2":
            from .w2_projection import W2Projection
            return W2Projection(**kwargs)
        elif projection_type == "frob":
            from .frob_projection import FrobProjection
            return FrobProjection(**kwargs)
        elif projection_type == "identity":
            from .identity_projection import IdentityProjection
            return IdentityProjection(**kwargs)
        else:
            raise ValueError(f"Unknown projection type: {projection_type}")
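
A minimal usage sketch (not part of the commit; the package import path, key names, and batch sizes are illustrative assumptions) showing the factory together with a tensordict.TensorDict:

import torch
from tensordict import TensorDict
from projections import BaseProjection  # package path assumed

# Build a KL projection via the factory and run it over a batch of
# old/new diagonal-Gaussian parameters.
proj = BaseProjection.make(
    "kl",
    in_keys=["mean", "std"],
    out_keys=["projected_mean", "projected_std"],
)
td = TensorDict({
    "mean": torch.zeros(32, 4),                # old policy mean
    "std": torch.ones(32, 4),                  # old policy std
    "projected_mean": torch.randn(32, 4),      # candidate new mean
    "projected_std": torch.ones(32, 4) * 1.5,  # candidate new std
}, batch_size=[32])
td = proj(td)  # overwrites projected_* with the trust-region-projected values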

View File: frob_projection.py

@@ -0,0 +1,22 @@
import torch
from .base_projection import BaseProjection
from typing import Dict, List


class FrobProjection(BaseProjection):
    def __init__(self, in_keys: List[str], out_keys: List[str], epsilon: float = 1e-3):
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.epsilon = epsilon

    def project(self, policy_params: Dict[str, torch.Tensor],
                old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Caps the Frobenius (Euclidean) norm of each parameter change at
        # epsilon, per batch element: updates that exceed the bound are
        # rescaled back onto the trust-region boundary.
        projected_params = {}
        for key in policy_params.keys():
            old_param = old_policy_params[key]
            new_param = policy_params[key]
            diff = new_param - old_param
            norm = torch.norm(diff, dim=-1, keepdim=True)
            projected_params[key] = torch.where(
                norm > self.epsilon,
                # 1e-8 guards the division for zero-length updates, which
                # torch.where still evaluates even when it selects new_param.
                old_param + (self.epsilon / (norm + 1e-8)) * diff,
                new_param,
            )
        return projected_params
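
A toy check (illustrative, not part of the commit): with epsilon = 1e-3, a mean update of norm 0.5 is pulled back onto the trust-region boundary while the unchanged std passes through:

import torch

proj = FrobProjection(in_keys=["mean", "std"],
                      out_keys=["projected_mean", "projected_std"],
                      epsilon=1e-3)
old = {"mean": torch.zeros(1, 4), "std": torch.ones(1, 4)}
new = {"mean": torch.full((1, 4), 0.25), "std": torch.ones(1, 4)}
out = proj.project(new, old)
print(torch.norm(out["mean"] - old["mean"], dim=-1))  # ~1e-3 (was 0.5)
print(torch.equal(out["std"], new["std"]))            # True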

View File: identity_projection.py

@@ -0,0 +1,11 @@
import torch
from .base_projection import BaseProjection
from typing import Dict, List


class IdentityProjection(BaseProjection):
    def __init__(self, in_keys: List[str], out_keys: List[str]):
        super().__init__(in_keys=in_keys, out_keys=out_keys)

    def project(self, policy_params: Dict[str, torch.Tensor],
                old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # The identity projection simply returns the new policy parameters
        # without any modification.
        return policy_params

View File: kl_projection.py

@@ -0,0 +1,33 @@
import torch
from typing import Dict, List
from .base_projection import BaseProjection


class KLProjection(BaseProjection):
    def __init__(
        self,
        in_keys: List[str] = ["mean", "std"],
        out_keys: List[str] = ["projected_mean", "projected_std"],
        epsilon: float = 0.1
    ):
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.epsilon = epsilon

    def project(self, policy_params: Dict[str, torch.Tensor],
                old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        new_mean, new_std = policy_params["mean"], policy_params["std"]
        old_mean, old_std = old_policy_params["mean"], old_policy_params["std"]
        diff = new_mean - old_mean
        std_diff = new_std - old_std
        # Closed-form KL(new || old) between diagonal Gaussians:
        # 0.5 * [ sum(s_new^2 / s_old^2) + sum(diff^2 / s_old^2) - k
        #         + 2 * sum(log(s_old / s_new)) ]
        kl = 0.5 * (torch.sum(torch.square(new_std / old_std), dim=-1) +
                    torch.sum(torch.square(diff / old_std), dim=-1) -
                    new_mean.shape[-1] +
                    2 * torch.sum(torch.log(old_std / new_std), dim=-1))
        # Interpolate linearly back towards the old parameters until the
        # bound epsilon is (approximately) met; factor 1 keeps the new
        # parameters unchanged.
        factor = torch.sqrt(self.epsilon / (kl + 1e-8))
        factor = torch.clamp(factor, max=1.0)
        projected_mean = old_mean + factor.unsqueeze(-1) * diff
        projected_std = old_std + factor.unsqueeze(-1) * std_diff
        return {"mean": projected_mean, "std": projected_std}

View File: w2_projection.py

@@ -0,0 +1,73 @@
import torch
from .base_projection import BaseProjection
from typing import Dict, List, Tuple


class W2Projection(BaseProjection):
    def __init__(self,
                 in_keys: List[str],
                 out_keys: List[str],
                 scale_prec: bool = False):
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.scale_prec = scale_prec

    def project(self, policy_params: Dict[str, torch.Tensor],
                old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Accepts both the flat {"mean", "std"} dicts produced by
        # BaseProjection.forward and parameter dicts keyed by '.loc'/'.scale'
        # suffixes; any other entries are passed through unchanged.
        projected_params = {}
        for key in policy_params.keys():
            if key == "mean" or key.endswith('.loc'):
                std_key = "std" if key == "mean" else key.replace('.loc', '.scale')
                mean = policy_params[key]
                old_mean = old_policy_params[key]
                std = policy_params[std_key]
                old_std = old_policy_params[std_key]
                projected_mean, projected_std = self._trust_region_projection(
                    mean, std, old_mean, old_std
                )
                projected_params[key] = projected_mean
                projected_params[std_key] = projected_std
            elif key != "std" and not key.endswith('.scale'):
                projected_params[key] = policy_params[key]
        return projected_params
    def _trust_region_projection(self, mean: torch.Tensor, std: torch.Tensor,
                                 old_mean: torch.Tensor, old_std: torch.Tensor,
                                 eps: float = 1e-3, eps_cov: float = 1e-3) -> Tuple[torch.Tensor, torch.Tensor]:
        mean_part, cov_part = self._gaussian_wasserstein_commutative(mean, std, old_mean, old_std)

        # Project the mean: rescale the mean shift onto the trust-region
        # boundary wherever the mean part of the W2 distance exceeds eps.
        # The mask is per batch element, so broadcast it over the last dim.
        mean_mask = mean_part > eps
        proj_mean = torch.where(mean_mask[..., None],
                                old_mean + (mean - old_mean) * torch.sqrt(eps / (mean_part + 1e-8))[..., None],
                                mean)

        # Project the covariance: eta = 0 leaves entries that already satisfy
        # the bound unchanged; larger eta interpolates towards the old std.
        cov_mask = cov_part > eps_cov
        eta = torch.zeros_like(cov_part)
        eta[cov_mask] = torch.sqrt(cov_part[cov_mask] / eps_cov) - 1.
        eta = torch.clamp(eta, min=-0.9)  # avoid invalid (negative) standard deviations
        proj_std = (std + eta[..., None] * old_std) / (1. + eta[..., None] + 1e-8)
        return proj_mean, proj_std

    def _gaussian_wasserstein_commutative(self, mean: torch.Tensor, std: torch.Tensor,
                                          old_mean: torch.Tensor, old_std: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.scale_prec:
            # Mahalanobis distance for the mean part
            mean_part = ((mean - old_mean) ** 2 / (old_std ** 2 + 1e-8)).sum(-1)
        else:
            # Euclidean distance for the mean part
            mean_part = ((mean - old_mean) ** 2).sum(-1)
        # W2 objective for the (commuting, diagonal) covariances
        cov_part = (std ** 2 + old_std ** 2 - 2 * std * old_std).sum(-1)
        return mean_part, cov_part

    @classmethod
    def make(cls, in_keys: List[str], out_keys: List[str], **kwargs) -> 'W2Projection':
        return cls(in_keys=in_keys, out_keys=out_keys, **kwargs)
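
A usage sketch (illustrative, not part of the commit); the '.loc'/'.scale' key names follow the suffix convention that project recognizes:

import torch

proj = W2Projection(in_keys=["mean", "std"],
                    out_keys=["projected_mean", "projected_std"],
                    scale_prec=False)
old = {"policy.loc": torch.zeros(8, 4), "policy.scale": torch.ones(8, 4)}
new = {"policy.loc": torch.randn(8, 4), "policy.scale": torch.ones(8, 4) * 2.0}
out = proj.project(new, old)
# The mean and covariance parts of the W2 distance are bounded separately
# (eps, eps_cov), so both loc and scale move back towards the old values.
print(out["policy.loc"].shape, out["policy.scale"].shape)  # both torch.Size([8, 4])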