Fixing issues with projections

Dominik Moritz Roth 2024-10-21 15:23:17 +02:00
parent 71cb8593d9
commit 651ef1522f
5 changed files with 199 additions and 108 deletions

View File

@@ -1,16 +1,71 @@
 from abc import ABC, abstractmethod
 import torch
-from typing import Dict
+from torch import nn
+from typing import Dict, List

-class BaseProjection(ABC, torch.nn.Module):
-    def __init__(self, in_keys: list[str], out_keys: list[str]):
+class BaseProjection(nn.Module, ABC):
+    def __init__(self, in_keys: List[str], out_keys: List[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
         super().__init__()
+        self._validate_in_keys(in_keys)
+        self._validate_out_keys(out_keys)
         self.in_keys = in_keys
         self.out_keys = out_keys
+        self.trust_region_coeff = trust_region_coeff
+        self.mean_bound = mean_bound
+        self.cov_bound = cov_bound
+        self.full_cov = "scale_tril" in in_keys
+        self.contextual_std = contextual_std
+
+    def _validate_in_keys(self, keys: List[str]):
+        valid_keys = {"loc", "scale", "scale_tril", "old_loc", "old_scale", "old_scale_tril"}
+        if not set(keys).issubset(valid_keys):
+            raise ValueError(f"Invalid in_keys: {keys}. Must be a subset of {valid_keys}")
+        if "loc" not in keys or "old_loc" not in keys:
+            raise ValueError("Both 'loc' and 'old_loc' must be included in in_keys")
+        if ("scale" in keys) != ("old_scale" in keys) or ("scale_tril" in keys) != ("old_scale_tril" in keys):
+            raise ValueError("in_keys must have matching 'scale'/'old_scale' or 'scale_tril'/'old_scale_tril'")
+
+    def _validate_out_keys(self, keys: List[str]):
+        valid_keys = {"loc", "scale", "scale_tril"}
+        if not set(keys).issubset(valid_keys):
+            raise ValueError(f"Invalid out_keys: {keys}. Must be a subset of {valid_keys}")
+        if "loc" not in keys:
+            raise ValueError("'loc' must be included in out_keys")
+        if "scale" not in keys and "scale_tril" not in keys:
+            raise ValueError("Either 'scale' or 'scale_tril' must be included in out_keys")

     @abstractmethod
     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         pass

-    def forward(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        return self.project(policy_params, old_policy_params)
+    @abstractmethod
+    def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
+        pass
+
+    def forward(self, tensordict):
+        policy_params = {}
+        old_policy_params = {}
+        for key in self.in_keys:
+            if key not in tensordict:
+                raise KeyError(f"Key '{key}' not found in tensordict. Available keys: {tensordict.keys()}")
+            if key.startswith("old_"):
+                old_policy_params[key[4:]] = tensordict[key]
+            else:
+                policy_params[key] = tensordict[key]
+        projected_params = self.project(policy_params, old_policy_params)
+        return projected_params
+
+    def _calc_covariance(self, params: Dict[str, torch.Tensor]) -> torch.Tensor:
+        if not self.full_cov:
+            return torch.diag_embed(params["scale"].pow(2))
+        else:
+            return torch.matmul(params["scale_tril"], params["scale_tril"].transpose(-1, -2))
+
+    def _calc_scale_or_scale_tril(self, cov: torch.Tensor) -> torch.Tensor:
+        if not self.full_cov:
+            return torch.sqrt(cov.diagonal(dim1=-2, dim2=-1))
+        else:
+            return torch.linalg.cholesky(cov)
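
Note: forward now takes a single tensordict-style mapping instead of two parameter dicts, splits out the "old_"-prefixed entries, and delegates to project. A minimal usage sketch under stated assumptions — a plain dict stands in for the TensorDict, the module path is hypothetical, and the concrete subclass (FrobeniusProjection from the next file) relies on a _gaussian_frobenius helper not shown in this diff:

    import torch
    from frobenius_projection import FrobeniusProjection  # import path assumed for illustration

    proj = FrobeniusProjection(
        in_keys=["loc", "scale", "old_loc", "old_scale"],  # diagonal Gaussian: full_cov stays False
        out_keys=["loc", "scale"],
    )
    batch, dim = 32, 4
    td = {  # plain dict mimicking the tensordict interface
        "loc": torch.randn(batch, dim),
        "scale": torch.rand(batch, dim) + 0.1,
        "old_loc": torch.randn(batch, dim),
        "old_scale": torch.rand(batch, dim) + 0.1,
    }
    out = proj(td)  # keys validated in __init__, split into new/old params, then projected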

View File

@@ -1,33 +1,34 @@
 import torch
 from .base_projection import BaseProjection
+from tensordict.nn import TensorDictModule
 from typing import Dict

 class FrobeniusProjection(BaseProjection):
-    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
-        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
+    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
+        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
         self.scale_prec = scale_prec

     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        mean, chol = policy_params["loc"], policy_params["scale_tril"]
-        old_mean, old_chol = old_policy_params["loc"], old_policy_params["scale_tril"]
-        cov = torch.matmul(chol, chol.transpose(-1, -2))
-        old_cov = torch.matmul(old_chol, old_chol.transpose(-1, -2))
+        mean = policy_params["loc"]
+        old_mean = old_policy_params["loc"]
+        cov = self._calc_covariance(policy_params)
+        old_cov = self._calc_covariance(old_policy_params)

         mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
         proj_mean = self._mean_projection(mean, old_mean, mean_part)
         proj_cov = self._cov_projection(cov, old_cov, cov_part)
-        proj_chol = torch.linalg.cholesky(proj_cov)
-        return {"loc": proj_mean, "scale_tril": proj_chol}
+        scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov)
+        return {"loc": proj_mean, self.out_keys[1]: scale_or_scale_tril}

     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
-        mean, chol = policy_params["loc"], policy_params["scale_tril"]
-        proj_mean, proj_chol = proj_policy_params["loc"], proj_policy_params["scale_tril"]
-        cov = torch.matmul(chol, chol.transpose(-1, -2))
-        proj_cov = torch.matmul(proj_chol, proj_chol.transpose(-1, -2))
+        mean = policy_params["loc"]
+        proj_mean = proj_policy_params["loc"]
+        cov = self._calc_covariance(policy_params)
+        proj_cov = self._calc_covariance(proj_policy_params)
         mean_diff = torch.sum(torch.square(mean - proj_mean), dim=-1)
         cov_diff = torch.sum(torch.square(cov - proj_cov), dim=(-2, -1))
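
Note: the hunk is cut off before the return, but mean_diff and cov_diff are the two summands of the Frobenius trust-region loss; assuming the truncated return combines them with the coefficient, the quantity is

    L = \lambda \left( \lVert \mu - \mu_{\mathrm{proj}} \rVert_2^2 + \lVert \Sigma - \Sigma_{\mathrm{proj}} \rVert_F^2 \right)

with \lambda = trust_region_coeff, i.e. squared Euclidean distance between the means plus squared Frobenius distance between the covariances.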

View File

@@ -3,8 +3,8 @@ from .base_projection import BaseProjection
 from typing import Dict

 class IdentityProjection(BaseProjection):
-    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01):
-        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
+    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
+        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)

     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         return policy_params

View File

@@ -2,6 +2,7 @@ import torch
 import cpp_projection
 import numpy as np
 from .base_projection import BaseProjection
+from tensordict.nn import TensorDictModule
 from typing import Dict, Tuple, Any

 MAX_EVAL = 1000
@@ -10,57 +11,65 @@ def get_numpy(tensor):
     return tensor.detach().cpu().numpy()

 class KLProjection(BaseProjection):
-    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, is_diag: bool = True, contextual_std: bool = True):
-        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
-        self.is_diag = is_diag
-        self.contextual_std = contextual_std
+    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
+        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)

     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        mean, std = policy_params["loc"], policy_params["scale_tril"]
-        old_mean, old_std = old_policy_params["loc"], old_policy_params["scale_tril"]
+        mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
+        old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params[self.in_keys[1]]

-        mean_part, cov_part = self._gaussian_kl((mean, std), (old_mean, old_std))
+        mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril), (old_mean, old_scale_or_tril))

         if not self.contextual_std:
-            std = std[:1]
-            old_std = old_std[:1]
+            scale_or_tril = scale_or_tril[:1]
+            old_scale_or_tril = old_scale_or_tril[:1]
             cov_part = cov_part[:1]

         proj_mean = self._mean_projection(mean, old_mean, mean_part)
-        proj_std = self._cov_projection(std, old_std, cov_part)
+        proj_scale_or_tril = self._cov_projection(scale_or_tril, old_scale_or_tril, cov_part)

         if not self.contextual_std:
-            proj_std = proj_std.expand(mean.shape[0], -1, -1)
+            proj_scale_or_tril = proj_scale_or_tril.expand(mean.shape[0], *proj_scale_or_tril.shape[1:])

-        return {"loc": proj_mean, "scale_tril": proj_std}
+        return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_tril}

     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
-        mean, std = policy_params["loc"], policy_params["scale_tril"]
-        proj_mean, proj_std = proj_policy_params["loc"], proj_policy_params["scale_tril"]
-        kl = sum(self._gaussian_kl((mean, std), (proj_mean, proj_std)))
+        mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
+        proj_mean, proj_scale_or_tril = proj_policy_params["loc"], proj_policy_params[self.out_keys[1]]
+        kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
         return kl.mean() * self.trust_region_coeff

     def _gaussian_kl(self, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
-        mean, std = p
-        mean_other, std_other = q
+        mean, scale_or_tril = p
+        mean_other, scale_or_tril_other = q
         k = mean.shape[-1]

-        maha_part = 0.5 * self._maha(mean, mean_other, std_other)
-        det_term = self._log_determinant(std)
-        det_term_other = self._log_determinant(std_other)
-        trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(std_other, std, upper=False))
+        maha_part = 0.5 * self._maha(mean, mean_other, scale_or_tril_other)
+        det_term = self._log_determinant(scale_or_tril)
+        det_term_other = self._log_determinant(scale_or_tril_other)
+        if self.full_cov:
+            trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(scale_or_tril_other, scale_or_tril, upper=False))
+        else:
+            trace_part = torch.sum((scale_or_tril / scale_or_tril_other) ** 2, dim=-1)
         cov_part = 0.5 * (trace_part - k + det_term_other - det_term)

         return maha_part, cov_part

-    def _maha(self, x: torch.Tensor, y: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
+    def _maha(self, x: torch.Tensor, y: torch.Tensor, scale_or_tril: torch.Tensor) -> torch.Tensor:
         diff = x - y
-        return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), std, upper=False)[0].squeeze(-1)), dim=-1)
+        if self.full_cov:
+            return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), scale_or_tril, upper=False)[0].squeeze(-1)), dim=-1)
+        else:
+            return torch.sum(torch.square(diff / scale_or_tril), dim=-1)

-    def _log_determinant(self, std: torch.Tensor) -> torch.Tensor:
-        return 2 * torch.log(std.diagonal(dim1=-2, dim2=-1)).sum(-1)
+    def _log_determinant(self, scale_or_tril: torch.Tensor) -> torch.Tensor:
+        if self.full_cov:
+            return 2 * torch.log(scale_or_tril.diagonal(dim1=-2, dim2=-1)).sum(-1)
+        else:
+            return 2 * torch.log(scale_or_tril).sum(-1)

     def _torch_batched_trace_square(self, x: torch.Tensor) -> torch.Tensor:
         return torch.sum(x.pow(2), dim=(-2, -1))
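
Note: _gaussian_kl is the standard closed form of the KL divergence between two Gaussians, split into the two parts the projection bounds separately (k is the action dimension, q the old policy):

    KL(p \,\|\, q) = \underbrace{\tfrac{1}{2}(\mu_p - \mu_q)^\top \Sigma_q^{-1}(\mu_p - \mu_q)}_{\text{maha\_part}} + \underbrace{\tfrac{1}{2}\big(\operatorname{tr}(\Sigma_q^{-1}\Sigma_p) - k + \log\det\Sigma_q - \log\det\Sigma_p\big)}_{\text{cov\_part}}

In the diagonal branch the trace term reduces to \sum_i (\sigma_{p,i}/\sigma_{q,i})^2, which is exactly what the new trace_part computes.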
@@ -68,49 +77,45 @@ class KLProjection(BaseProjection):
     def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
         return old_mean + (mean - old_mean) * torch.sqrt(self.mean_bound / (mean_part + 1e-8)).unsqueeze(-1)

-    def _cov_projection(self, std: torch.Tensor, old_std: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
-        cov = torch.matmul(std, std.transpose(-1, -2))
-        old_cov = torch.matmul(old_std, old_std.transpose(-1, -2))
-        if self.is_diag:
-            mask = cov_part > self.cov_bound
-            proj_std = torch.zeros_like(std)
-            proj_std[~mask] = std[~mask]
-            try:
-                if mask.any():
-                    proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov.diagonal(dim1=-2, dim2=-1),
-                                                                         old_cov.diagonal(dim1=-2, dim2=-1),
-                                                                         self.cov_bound)
-                    is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
-                    if is_invalid.any():
-                        proj_std[is_invalid] = old_std[is_invalid]
-                        mask &= ~is_invalid
-                    proj_std[mask] = proj_cov[mask].sqrt().diag_embed()
-            except Exception as e:
-                proj_std = old_std
-        else:
-            try:
-                mask = cov_part > self.cov_bound
-                proj_std = torch.zeros_like(std)
-                proj_std[~mask] = std[~mask]
-                if mask.any():
-                    proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, std.detach(), old_std, self.cov_bound)
-                    is_invalid = proj_cov.mean([-2, -1]).isnan() & mask
-                    if is_invalid.any():
-                        proj_std[is_invalid] = old_std[is_invalid]
-                        mask &= ~is_invalid
-                    proj_std[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
-                    failed_mask = failed_mask.bool()
-                    if failed_mask.any():
-                        proj_std[failed_mask] = old_std[failed_mask]
-            except Exception as e:
-                import logging
-                logging.error('Projection failed, taking old cholesky for projection.')
-                print("Projection failed, taking old cholesky for projection.")
-                proj_std = old_std
-                raise e
-        return proj_std
+    def _cov_projection(self, scale_or_tril: torch.Tensor, old_scale_or_tril: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
+        if self.full_cov:
+            cov = torch.matmul(scale_or_tril, scale_or_tril.transpose(-1, -2))
+            old_cov = torch.matmul(old_scale_or_tril, old_scale_or_tril.transpose(-1, -2))
+        else:
+            cov = scale_or_tril.pow(2)
+            old_cov = old_scale_or_tril.pow(2)
+        mask = cov_part > self.cov_bound
+        proj_scale_or_tril = torch.zeros_like(scale_or_tril)
+        proj_scale_or_tril[~mask] = scale_or_tril[~mask]
+        try:
+            if mask.any():
+                if self.full_cov:
+                    proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, scale_or_tril.detach(), old_scale_or_tril, self.cov_bound)
+                    is_invalid = proj_cov.mean([-2, -1]).isnan() & mask
+                    if is_invalid.any():
+                        proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
+                        mask &= ~is_invalid
+                    proj_scale_or_tril[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
+                    failed_mask = failed_mask.bool()
+                    if failed_mask.any():
+                        proj_scale_or_tril[failed_mask] = old_scale_or_tril[failed_mask]
+                else:
+                    proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov, old_cov, self.cov_bound)
+                    is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
+                    if is_invalid.any():
+                        proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
+                        mask &= ~is_invalid
+                    proj_scale_or_tril[mask] = proj_cov[mask].sqrt()
+        except Exception as e:
+            import logging
+            logging.error('Projection failed, taking old scale_or_tril for projection.')
+            print("Projection failed, taking old scale_or_tril for projection.")
+            proj_scale_or_tril = old_scale_or_tril
+            raise e
+        return proj_scale_or_tril

 class KLProjectionGradFunctionCovOnly(torch.autograd.Function):
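
Note: a minimal instantiation sketch for the diagonal case. The bound values are illustrative only, and the compiled cpp_projection extension must be importable, since both autograd functions dispatch to it:

    proj = KLProjection(
        in_keys=["loc", "scale", "old_loc", "old_scale"],  # no "scale_tril", so full_cov is False
        out_keys=["loc", "scale"],
        mean_bound=0.03,          # illustrative bounds only
        cov_bound=0.001,
        trust_region_coeff=1.0,
    )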

View File

@@ -1,56 +1,86 @@
 import torch
 from .base_projection import BaseProjection
+from tensordict.nn import TensorDictModule
 from typing import Dict, Tuple

+def scale_tril_to_sqrt(scale_tril: torch.Tensor) -> torch.Tensor:
+    """
+    'Converts' scale_tril to scale_sqrt.
+    For Wasserstein distance, we need the matrix square root, not the Cholesky decomposition.
+    But since both are lower triangular, we can treat the Cholesky decomposition as if it were the matrix square root.
+    """
+    return scale_tril
+
 def gaussian_wasserstein_commutative(policy, p: Tuple[torch.Tensor, torch.Tensor],
                                      q: Tuple[torch.Tensor, torch.Tensor], scale_prec=False) -> Tuple[torch.Tensor, torch.Tensor]:
-    mean, sqrt = p
-    mean_other, sqrt_other = q
+    mean, scale_or_sqrt = p
+    mean_other, scale_or_sqrt_other = q
     mean_part = torch.sum(torch.square(mean - mean_other), dim=-1)

-    cov = torch.matmul(sqrt, sqrt.transpose(-1, -2))
-    cov_other = torch.matmul(sqrt_other, sqrt_other.transpose(-1, -2))
-
-    if scale_prec:
-        identity = torch.eye(mean.shape[-1], dtype=sqrt.dtype, device=sqrt.device)
-        sqrt_inv_other = torch.linalg.solve(sqrt_other, identity)
-        c = sqrt_inv_other @ cov @ sqrt_inv_other
-        cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ sqrt)
-    else:
-        cov_part = torch.trace(cov_other + cov - 2 * sqrt_other @ sqrt)
+    if scale_or_sqrt.dim() == mean.dim():  # Diagonal case
+        cov = scale_or_sqrt.pow(2)
+        cov_other = scale_or_sqrt_other.pow(2)
+        if scale_prec:
+            identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device)
+            sqrt_inv_other = 1 / scale_or_sqrt_other
+            c = sqrt_inv_other.pow(2) * cov
+            cov_part = torch.sum(identity + c - 2 * sqrt_inv_other * scale_or_sqrt, dim=-1)
+        else:
+            cov_part = torch.sum(cov_other + cov - 2 * scale_or_sqrt_other * scale_or_sqrt, dim=-1)
+    else:  # Full covariance case
+        # Note: scale_or_sqrt is treated as the matrix square root, not the Cholesky decomposition
+        cov = torch.matmul(scale_or_sqrt, scale_or_sqrt.transpose(-1, -2))
+        cov_other = torch.matmul(scale_or_sqrt_other, scale_or_sqrt_other.transpose(-1, -2))
+        if scale_prec:
+            identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device)
+            sqrt_inv_other = torch.linalg.solve(scale_or_sqrt_other, identity)
+            c = sqrt_inv_other @ cov @ sqrt_inv_other.transpose(-1, -2)
+            cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ scale_or_sqrt)
+        else:
+            cov_part = torch.trace(cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt)

     return mean_part, cov_part
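
Note: under the commutativity assumption \Sigma_p \Sigma_q = \Sigma_q \Sigma_p that gives the function its name, the squared 2-Wasserstein distance between the two Gaussians factors as

    W_2^2(p, q) = \lVert \mu_p - \mu_q \rVert_2^2 + \operatorname{tr}\big(\Sigma_p + \Sigma_q - 2\,\Sigma_p^{1/2}\Sigma_q^{1/2}\big)

which is exactly mean_part plus the else-branch's cov_part. With scale_prec=True the covariance term is instead measured in the space whitened by q (metric \Sigma_q^{-1}), matching the sqrt_inv_other factors above.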
 class WassersteinProjection(BaseProjection):
-    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
-        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
+    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
+        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
         self.scale_prec = scale_prec

     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
-        old_mean, old_sqrt = old_policy_params["loc"], old_policy_params["scale_tril"]
-        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (old_mean, old_sqrt), self.scale_prec)
+        mean = policy_params["loc"]
+        old_mean = old_policy_params["loc"]
+        scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
+        old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params[self.in_keys[1]])
+        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (old_mean, old_scale_or_sqrt), self.scale_prec)
         proj_mean = self._mean_projection(mean, old_mean, mean_part)
-        proj_sqrt = self._cov_projection(sqrt, old_sqrt, cov_part)
-        return {"loc": proj_mean, "scale_tril": proj_sqrt}
+        proj_scale_or_sqrt = self._cov_projection(scale_or_sqrt, old_scale_or_sqrt, cov_part)
+        return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_sqrt}

     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
-        mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
-        proj_mean, proj_sqrt = proj_policy_params["loc"], proj_policy_params["scale_tril"]
-        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (proj_mean, proj_sqrt), self.scale_prec)
+        mean = policy_params["loc"]
+        proj_mean = proj_policy_params["loc"]
+        scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
+        proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params[self.out_keys[1]])
+        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (proj_mean, proj_scale_or_sqrt), self.scale_prec)
         w2 = mean_part + cov_part
         return w2.mean() * self.trust_region_coeff

     def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
         diff = mean - old_mean
-        norm = torch.norm(diff, dim=-1, keepdim=True)
-        return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm, mean)
+        norm = torch.sqrt(mean_part)
+        return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm.unsqueeze(-1), mean)

-    def _cov_projection(self, sqrt: torch.Tensor, old_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
-        diff = sqrt - old_sqrt
-        norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
-        return torch.where(norm > self.cov_bound, old_sqrt + diff * self.cov_bound / norm, sqrt)
+    def _cov_projection(self, scale_or_sqrt: torch.Tensor, old_scale_or_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
+        if scale_or_sqrt.dim() == old_scale_or_sqrt.dim() == 2:  # Diagonal case
+            diff = scale_or_sqrt - old_scale_or_sqrt
+            norm = torch.sqrt(cov_part)
+            return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm.unsqueeze(-1), scale_or_sqrt)
+        else:  # Full covariance case
+            diff = scale_or_sqrt - old_scale_or_sqrt
+            norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
+            return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm, scale_or_sqrt)
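
Note: _mean_projection and _cov_projection apply the same soft clipping rule: whenever the respective distance part exceeds its bound, the difference to the old parameters is rescaled so the bound holds with equality. A standalone sketch of that rule; the explicit unsqueeze of the norm before the comparison is an assumption for shape safety, since torch.where needs the mask to broadcast against the parameter shape:

    import torch

    def clip_to_bound(new: torch.Tensor, old: torch.Tensor, part: torch.Tensor, bound: float) -> torch.Tensor:
        # new, old: [batch, dim]; part: per-sample squared distance contribution, shape [batch]
        norm = torch.sqrt(part).unsqueeze(-1)             # [batch, 1], broadcasts over the last dim
        projected = old + (new - old) * bound / norm      # rescaled so the distance equals the bound
        return torch.where(norm > bound, projected, new)  # only touch the violating batch elements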