Port some fixes / additions learned from implementing itpal_jax

Dominik Moritz Roth 2025-01-22 14:03:23 +01:00
parent 3816adef9a
commit ecf0b72e88
4 changed files with 240 additions and 103 deletions

base_projection.py

@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 import torch
 from torch import nn
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 class BaseProjection(nn.Module, ABC):
     def __init__(self, in_keys: List[str], out_keys: List[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
@@ -68,4 +68,25 @@ class BaseProjection(nn.Module, ABC):
         if not self.full_cov:
             return torch.sqrt(cov.diagonal(dim1=-2, dim2=-1))
         else:
             return torch.linalg.cholesky(cov)
+
+    def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
+        """Project mean based on the Mahalanobis objective and trust region.
+
+        Args:
+            mean: Current mean vectors
+            old_mean: Old mean vectors
+            mean_part: Mahalanobis/Euclidean distance between the two mean vectors
+
+        Returns:
+            Projected mean that satisfies the trust region
+        """
+        mask = mean_part > self.mean_bound
+        omega = torch.ones_like(mean_part, device=mean.device)
+        omega = torch.where(mask, torch.sqrt(mean_part / self.mean_bound) - 1., omega)
+        omega = torch.maximum(-omega, omega)[..., None]
+        # Use matrix operations instead of boolean indexing
+        m = (mean + omega * old_mean) / (1. + omega + 1e-16)
+        mask_matrix = mask[..., None].to(mean.dtype)
+        return mask_matrix * m + (1 - mask_matrix) * mean
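The new `_mean_projection` uses the standard trust-region interpolation: when the mean divergence d exceeds the bound, it blends toward the old mean with omega = sqrt(d / bound) - 1, which places the result exactly on the trust-region boundary. A minimal standalone check of that property (Euclidean case; the values below are illustrative, not taken from the commit):

import torch

mean_bound = 0.01
old_mean = torch.zeros(4, 3)
mean = old_mean + 0.5  # squared distance 0.75 per row, well above the bound

d = torch.sum(torch.square(mean - old_mean), dim=-1)   # mean_part
omega = (torch.sqrt(d / mean_bound) - 1.)[..., None]
proj = (mean + omega * old_mean) / (1. + omega + 1e-16)

# The projected mean sits exactly on the trust-region boundary.
d_proj = torch.sum(torch.square(proj - old_mean), dim=-1)
print(torch.allclose(d_proj, torch.full_like(d_proj, mean_bound), atol=1e-5))  # True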

frobenius_projection.py

@@ -1,7 +1,7 @@
 import torch
 from .base_projection import BaseProjection
 from tensordict.nn import TensorDictModule
-from typing import Dict
+from typing import Dict, Tuple
 
 class FrobeniusProjection(BaseProjection):
     def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
@@ -12,16 +12,23 @@ class FrobeniusProjection(BaseProjection):
         mean = policy_params["loc"]
         old_mean = old_policy_params["loc"]
 
+        # Convert to covariance representation
         cov = self._calc_covariance(policy_params)
         old_cov = self._calc_covariance(old_policy_params)
 
-        mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
+        if not self.contextual_std:
+            cov = cov[:1]
+            old_cov = old_cov[:1]
+
+        mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
 
         proj_mean = self._mean_projection(mean, old_mean, mean_part)
         proj_cov = self._cov_projection(cov, old_cov, cov_part)
-        scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov)
-        return {"loc": proj_mean, self.out_keys[1]: scale_or_scale_tril}
+        scale_or_tril = self._calc_scale_or_scale_tril(proj_cov)
+        if not self.contextual_std:
+            scale_or_tril = scale_or_tril.expand(mean.shape[0], *scale_or_tril.shape[1:])
+        return {"loc": proj_mean, self.out_keys[1]: scale_or_tril}
 
     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
         mean = policy_params["loc"]
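
With `contextual_std=False` every batch row shares the same scale, so the hunk above projects a single representative row (`cov[:1]`) and re-expands the result over the batch afterwards. A minimal sketch of that round trip (hypothetical shapes and a stand-in for the actual projection):

import torch

batch, dim = 8, 3
scale = torch.rand(1, dim) + 0.1   # single shared (non-contextual) scale
projected = scale * 0.9            # stand-in for the actual projection step
restored = projected.expand(batch, *projected.shape[1:])
print(restored.shape)  # torch.Size([8, 3])
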
@@ -35,34 +42,48 @@ class FrobeniusProjection(BaseProjection):
         return (mean_diff + cov_diff).mean() * self.trust_region_coeff
 
-    def _gaussian_frobenius(self, p, q):
+    def _gaussian_frobenius(self, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         mean, cov = p
         old_mean, old_cov = q
 
         if self.scale_prec:
-            prec_old = torch.inverse(old_cov)
-            mean_part = torch.sum(torch.matmul(mean - old_mean, prec_old) * (mean - old_mean), dim=-1)
-            cov_part = torch.sum(prec_old * cov, dim=(-2, -1)) - torch.logdet(torch.matmul(prec_old, cov)) - mean.shape[-1]
+            if self.full_cov:
+                # Use triangular solve instead of inverse for stability
+                diff = mean - old_mean
+                solved = torch.triangular_solve(diff.unsqueeze(-1), old_cov, upper=False)[0].squeeze(-1)
+                mean_part = torch.sum(torch.square(solved), dim=-1)
+            else:
+                # Diagonal case - direct division is stable
+                mean_part = torch.sum(torch.square((mean - old_mean) / old_cov), dim=-1)
         else:
             mean_part = torch.sum(torch.square(mean - old_mean), dim=-1)
-            cov_part = torch.sum(torch.square(cov - old_cov), dim=(-2, -1))
+
+        # Covariance part
+        if self.full_cov:
+            diff = old_cov - cov
+            cov_part = torch.sum(torch.square(diff), dim=(-2, -1))
+        else:
+            diff = old_cov - cov
+            cov_part = torch.sum(torch.square(diff), dim=-1)
 
         return mean_part, cov_part
 
+    def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
+        diff = mean - old_mean
+        norm = torch.sqrt(mean_part)
+        return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm.unsqueeze(-1), mean)
+
     def _cov_projection(self, cov: torch.Tensor, old_cov: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
-        batch_shape = cov.shape[:-2]
+        batch_shape = cov.shape[:-2] if cov.ndim > 2 else cov.shape[:-1]
         cov_mask = cov_part > self.cov_bound
 
         eta = torch.ones(batch_shape, dtype=cov.dtype, device=cov.device)
-        eta[cov_mask] = torch.sqrt(cov_part[cov_mask] / self.cov_bound) - 1.
-        eta = torch.max(-eta, eta)
+        eta = torch.where(cov_mask, torch.sqrt(cov_part / self.cov_bound) - 1., eta)
+        eta = torch.maximum(-eta, eta)
 
-        new_cov = (cov + torch.einsum('i,ijk->ijk', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
-        proj_cov = torch.where(cov_mask[..., None, None], new_cov, cov)
+        if self.full_cov:
+            new_cov = (cov + torch.einsum('...,...ij->...ij', eta, old_cov)) / \
+                (1. + eta + 1e-16)[..., None, None]
+            mask_matrix = cov_mask[..., None, None].to(cov.dtype)
+            proj_cov = torch.where(mask_matrix, new_cov, cov)
+        else:
+            new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
+            mask_matrix = cov_mask[..., None].to(cov.dtype)
+            proj_cov = torch.where(mask_matrix, new_cov, cov)
 
         return proj_cov
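The covariance update mirrors the mean update: for eta >= 0, (cov + eta * old_cov) / (1 + eta) is a convex combination of the new and old covariance, so positive semi-definiteness is preserved without any boolean indexing. A quick numeric check of that property (illustrative matrices, not from the commit):

import torch

def random_psd(d):
    a = torch.randn(d, d)
    return a @ a.T + 1e-3 * torch.eye(d)  # strictly positive definite

cov, old_cov = random_psd(4), random_psd(4)
eta = 2.5
new_cov = (cov + eta * old_cov) / (1. + eta)

# Smallest eigenvalue stays positive for any eta >= 0.
print(torch.linalg.eigvalsh(new_cov).min() > 0)  # tensor(True)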

kl_projection.py

@@ -19,8 +19,17 @@ class KLProjection(BaseProjection):
         super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
 
     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
-        old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params[self.in_keys[1]]
+        self._validate_inputs(policy_params, old_policy_params)
+
+        mean = policy_params["loc"]
+        old_mean = old_policy_params["loc"]
+
+        if self.full_cov:
+            scale_or_tril = policy_params["scale_tril"]
+            old_scale_or_tril = old_policy_params["scale_tril"]
+        else:
+            scale_or_tril = policy_params["scale"]
+            old_scale_or_tril = old_policy_params["scale"]
 
         mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril), (old_mean, old_scale_or_tril))
@@ -35,11 +44,22 @@ class KLProjection(BaseProjection):
         if not self.contextual_std:
             proj_scale_or_tril = proj_scale_or_tril.expand(mean.shape[0], *proj_scale_or_tril.shape[1:])
 
-        return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_tril}
+        if self.full_cov:
+            return {"loc": proj_mean, "scale_tril": proj_scale_or_tril}
+        else:
+            return {"loc": proj_mean, "scale": proj_scale_or_tril}
 
     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
-        mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
-        proj_mean, proj_scale_or_tril = proj_policy_params["loc"], proj_policy_params[self.out_keys[1]]
+        mean = policy_params["loc"]
+        proj_mean = proj_policy_params["loc"]
+
+        if self.full_cov:
+            scale_or_tril = policy_params["scale_tril"]
+            proj_scale_or_tril = proj_policy_params["scale_tril"]
+        else:
+            scale_or_tril = policy_params["scale"]
+            proj_scale_or_tril = proj_policy_params["scale"]
 
         kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
         return kl.mean() * self.trust_region_coeff
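
The explicit "scale" / "scale_tril" keys selected above match the parameter names of torch.distributions.Normal (diagonal) and torch.distributions.MultivariateNormal (full covariance). For reference, with illustrative shapes:

import torch
from torch.distributions import Normal, MultivariateNormal

loc = torch.zeros(4, 3)
scale = torch.ones(4, 3)                    # diagonal case: "scale"
scale_tril = torch.eye(3).expand(4, 3, 3)   # full-covariance case: "scale_tril"

diag_dist = Normal(loc=loc, scale=scale)
full_dist = MultivariateNormal(loc=loc, scale_tril=scale_tril)
print(diag_dist.sample().shape, full_dist.sample().shape)  # (4, 3) each
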
@@ -54,7 +74,9 @@ class KLProjection(BaseProjection):
         det_term_other = self._log_determinant(scale_or_tril_other)
 
         if self.full_cov:
-            trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(scale_or_tril_other, scale_or_tril, upper=False))
+            trace_part = self._batched_trace_square(
+                torch.triangular_solve(scale_or_tril, scale_or_tril_other, upper=False)[0]
+            )
         else:
             trace_part = torch.sum((scale_or_tril / scale_or_tril_other) ** 2, dim=-1)
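
The trace term relies on the Cholesky identity tr(cov_other^-1 @ cov) = ||Lo^-1 @ L||_F^2 for cov = L @ L.T, so a single triangular solve replaces an explicit inverse. A standalone check of the identity (this sketch uses torch.linalg.solve_triangular; the commit itself calls the older torch.triangular_solve, which computes the same thing):

import torch

d = 4
L = torch.linalg.cholesky(torch.eye(d) + 0.1 * torch.ones(d, d))
Lo = torch.linalg.cholesky(2 * torch.eye(d))

solved = torch.linalg.solve_triangular(Lo, L, upper=False)  # Lo^-1 @ L
trace_via_solve = torch.sum(solved ** 2)
trace_direct = torch.trace(torch.linalg.inv(Lo @ Lo.T) @ (L @ L.T))
print(torch.allclose(trace_via_solve, trace_direct, atol=1e-5))  # True
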
@@ -65,62 +87,60 @@ class KLProjection(BaseProjection):
     def _maha(self, x: torch.Tensor, y: torch.Tensor, scale_or_tril: torch.Tensor) -> torch.Tensor:
         diff = x - y
         if self.full_cov:
-            return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), scale_or_tril, upper=False)[0].squeeze(-1)), dim=-1)
+            solved = torch.triangular_solve(diff.unsqueeze(-1), scale_or_tril, upper=False)[0]
+            return torch.sum(torch.square(solved.squeeze(-1)), dim=-1)
         else:
             return torch.sum(torch.square(diff / scale_or_tril), dim=-1)
 
     def _log_determinant(self, scale_or_tril: torch.Tensor) -> torch.Tensor:
         if self.full_cov:
-            return 2 * torch.log(scale_or_tril.diagonal(dim1=-2, dim2=-1)).sum(-1)
+            return 2 * torch.sum(torch.log(torch.diagonal(scale_or_tril, dim1=-2, dim2=-1)), dim=-1)
         else:
-            return 2 * torch.log(scale_or_tril).sum(-1)
+            return 2 * torch.sum(torch.log(scale_or_tril), dim=-1)
 
-    def _torch_batched_trace_square(self, x: torch.Tensor) -> torch.Tensor:
-        return torch.sum(x.pow(2), dim=(-2, -1))
+    def _batched_trace_square(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.sum(x ** 2, dim=(-2, -1))
+
+    def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
+        return old_mean + (mean - old_mean) * torch.sqrt(self.mean_bound / (mean_part + 1e-8)).unsqueeze(-1)
 
     def _cov_projection(self, scale_or_tril: torch.Tensor, old_scale_or_tril: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
         if self.full_cov:
             cov = torch.matmul(scale_or_tril, scale_or_tril.transpose(-1, -2))
             old_cov = torch.matmul(old_scale_or_tril, old_scale_or_tril.transpose(-1, -2))
         else:
-            cov = scale_or_tril.pow(2)
-            old_cov = old_scale_or_tril.pow(2)
+            cov = scale_or_tril ** 2
+            old_cov = old_scale_or_tril ** 2
 
         mask = cov_part > self.cov_bound
-        proj_scale_or_tril = torch.zeros_like(scale_or_tril)
-        proj_scale_or_tril[~mask] = scale_or_tril[~mask]
+        proj_scale_or_tril = scale_or_tril  # Start with original scale
 
-        try:
-            if mask.any():
-                if self.full_cov:
-                    proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, scale_or_tril.detach(), old_scale_or_tril, self.cov_bound)
-                    is_invalid = proj_cov.mean([-2, -1]).isnan() & mask
-                    if is_invalid.any():
-                        proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
-                        mask &= ~is_invalid
-                    proj_scale_or_tril[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
-                    failed_mask = failed_mask.bool()
-                    if failed_mask.any():
-                        proj_scale_or_tril[failed_mask] = old_scale_or_tril[failed_mask]
-                else:
-                    proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov, old_cov, self.cov_bound)
-                    is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
-                    if is_invalid.any():
-                        proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
-                        mask &= ~is_invalid
-                    proj_scale_or_tril[mask] = proj_cov[mask].sqrt()
-        except Exception as e:
-            import logging
-            logging.error('Projection failed, taking old scale_or_tril for projection.')
-            print("Projection failed, taking old scale_or_tril for projection.")
-            proj_scale_or_tril = old_scale_or_tril
-            raise e
+        if mask.any():
+            if self.full_cov:
+                proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound)
+                is_invalid = torch.isnan(proj_cov.mean(dim=(-2, -1)))
+                proj_scale_or_tril = torch.where(is_invalid[..., None, None], old_scale_or_tril, scale_or_tril)
+                mask = mask & ~is_invalid
+                chol = torch.linalg.cholesky(proj_cov)
+                proj_scale_or_tril = torch.where(mask[..., None, None], chol, proj_scale_or_tril)
+            else:
+                proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound)
+                is_invalid = (torch.isnan(proj_cov.mean(dim=-1)) |
+                              torch.isinf(proj_cov.mean(dim=-1)) |
+                              (proj_cov.min(dim=-1).values < 0))
+                proj_scale_or_tril = torch.where(is_invalid[..., None], old_scale_or_tril, scale_or_tril)
+                mask = mask & ~is_invalid
+                proj_scale_or_tril = torch.where(mask[..., None], torch.sqrt(proj_cov), scale_or_tril)
 
         return proj_scale_or_tril
 
+    def _validate_inputs(self, policy_params, old_policy_params):
+        if self.full_cov:
+            required_keys = ["loc", "scale_tril"]
+        else:
+            required_keys = ["loc", "scale"]
+        for key in required_keys:
+            if key not in policy_params or key not in old_policy_params:
+                raise KeyError(f"Missing required key '{key}' in policy parameters")
+
 class KLProjectionGradFunctionCovOnly(torch.autograd.Function):
     projection_op = None
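The rewritten `_cov_projection` replaces the try/except fallback with masked `torch.where` updates: compute a candidate projection, flag NaN/Inf/negative batch elements, and revert only those to the old scale. A shape-level sketch of that pattern, with a hypothetical stand-in for `project_diag_covariance` (which this commit references but does not show):

import torch

def fake_project(cov):
    # Hypothetical stand-in that "fails" for one batch element.
    out = cov * 0.5
    out[0] = float("nan")
    return out

scale = torch.ones(3, 2)
old_scale = 2 * torch.ones(3, 2)
proj_cov = fake_project(scale ** 2)

is_invalid = torch.isnan(proj_cov.mean(dim=-1)) | torch.isinf(proj_cov.mean(dim=-1))
proj_scale = torch.where(is_invalid[..., None], old_scale, torch.sqrt(proj_cov))
print(proj_scale)  # row 0 falls back to old_scale, rows 1-2 are projected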

wasserstein_projection.py

@@ -1,6 +1,5 @@
 import torch
 from .base_projection import BaseProjection
-from tensordict.nn import TensorDictModule
 from typing import Dict, Tuple
 
 def scale_tril_to_sqrt(scale_tril: torch.Tensor) -> torch.Tensor:
@@ -12,75 +11,151 @@ def scale_tril_to_sqrt(scale_tril: torch.Tensor) -> torch.Tensor:
     """
     return scale_tril
 
-def gaussian_wasserstein_commutative(policy, p: Tuple[torch.Tensor, torch.Tensor],
-                                     q: Tuple[torch.Tensor, torch.Tensor], scale_prec=False) -> Tuple[torch.Tensor, torch.Tensor]:
+def gaussian_wasserstein_commutative(p: Tuple[torch.Tensor, torch.Tensor],
+                                     q: Tuple[torch.Tensor, torch.Tensor],
+                                     scale_prec: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
     mean, scale_or_sqrt = p
     mean_other, scale_or_sqrt_other = q
 
     mean_part = torch.sum(torch.square(mean - mean_other), dim=-1)
 
     if scale_or_sqrt.dim() == mean.dim():  # Diagonal case
-        cov = scale_or_sqrt.pow(2)
-        cov_other = scale_or_sqrt_other.pow(2)
         if scale_prec:
-            identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device)
-            sqrt_inv_other = 1 / scale_or_sqrt_other
-            c = sqrt_inv_other.pow(2) * cov
-            cov_part = torch.sum(identity + c - 2 * sqrt_inv_other * scale_or_sqrt, dim=-1)
+            # More stable implementation for precision scaling
+            scale_part = torch.sum(
+                scale_or_sqrt_other**2 + scale_or_sqrt**2 -
+                2 * scale_or_sqrt_other * scale_or_sqrt,
+                dim=-1
+            )
         else:
-            cov_part = torch.sum(cov_other + cov - 2 * scale_or_sqrt_other * scale_or_sqrt, dim=-1)
+            # Standard W2 for diagonal case
+            scale_part = torch.sum(
+                scale_or_sqrt_other**2 + scale_or_sqrt**2 -
+                2 * scale_or_sqrt_other * scale_or_sqrt,
+                dim=-1
+            )
     else:  # Full covariance case
         # Note: scale_or_sqrt is treated as the matrix square root, not Cholesky decomposition
-        cov = torch.matmul(scale_or_sqrt, scale_or_sqrt.transpose(-1, -2))
-        cov_other = torch.matmul(scale_or_sqrt_other, scale_or_sqrt_other.transpose(-1, -2))
         if scale_prec:
+            # More stable implementation using triangular solve
             identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device)
-            sqrt_inv_other = torch.linalg.solve(scale_or_sqrt_other, identity)
-            c = sqrt_inv_other @ cov @ sqrt_inv_other.transpose(-1, -2)
-            cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ scale_or_sqrt)
+            sqrt_inv_other = torch.triangular_solve(identity, scale_or_sqrt_other, upper=False)[0]
+            c = torch.matmul(sqrt_inv_other, scale_or_sqrt)
+            scale_part = torch.sum(identity**2 + c**2 - 2 * c, dim=(-2, -1))
         else:
-            cov_part = torch.trace(cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt)
+            # Standard W2 for full covariance
+            scale_part = torch.sum(
+                scale_or_sqrt_other**2 + scale_or_sqrt**2 -
+                2 * torch.matmul(scale_or_sqrt_other, scale_or_sqrt.transpose(-1, -2)),
+                dim=(-2, -1)
+            )
 
-    return mean_part, cov_part
+    return mean_part, scale_part
 
 class WassersteinProjection(BaseProjection):
-    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
-        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
+    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0,
+                 mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False,
+                 contextual_std: bool = True):
+        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff,
+                         mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
         self.scale_prec = scale_prec
 
     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         mean = policy_params["loc"]
         old_mean = old_policy_params["loc"]
-        scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
-        old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params[self.in_keys[1]])
+        # scale_tril is already Cholesky of matrix sqrt
+        scale_sqrt = policy_params[self.in_keys[1]]
+        old_scale_sqrt = old_policy_params[self.in_keys[1]]
 
-        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (old_mean, old_scale_or_sqrt), self.scale_prec)
+        if not self.contextual_std:
+            scale_sqrt = scale_sqrt[:1]
+            old_scale_sqrt = old_scale_sqrt[:1]
+
+        mean_part, scale_part = self._gaussian_wasserstein(
+            (mean, scale_sqrt),
+            (old_mean, old_scale_sqrt)
+        )
 
         proj_mean = self._mean_projection(mean, old_mean, mean_part)
-        proj_scale_or_sqrt = self._cov_projection(scale_or_sqrt, old_scale_or_sqrt, cov_part)
-        return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_sqrt}
+        proj_scale_sqrt = self._scale_projection(scale_sqrt, old_scale_sqrt, scale_part)
+
+        if not self.contextual_std:
+            proj_scale_sqrt = proj_scale_sqrt.expand(mean.shape[0], *proj_scale_sqrt.shape[1:])
+
+        return {"loc": proj_mean, self.out_keys[1]: proj_scale_sqrt}
+
+    def _gaussian_wasserstein(self, p: Tuple[torch.Tensor, torch.Tensor],
+                              q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+        mean, scale_sqrt = p
+        mean_other, scale_sqrt_other = q
+
+        mean_part = torch.sum(torch.square(mean - mean_other), dim=-1)
+
+        if not self.full_cov:
+            # Diagonal case is simpler
+            scale_part = torch.sum(
+                scale_sqrt_other**2 + scale_sqrt**2 -
+                2 * scale_sqrt_other * scale_sqrt,
+                dim=-1
+            )
+        else:
+            # Full covariance case uses matrix operations
+            scale_part = torch.sum(
+                scale_sqrt_other**2 + scale_sqrt**2 -
+                2 * torch.matmul(scale_sqrt_other, scale_sqrt.transpose(-1, -2)),
+                dim=(-2, -1)
+            )
+
+        return mean_part, scale_part
 
     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
         mean = policy_params["loc"]
         proj_mean = proj_policy_params["loc"]
         scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
         proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params[self.out_keys[1]])
-        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (proj_mean, proj_scale_or_sqrt), self.scale_prec)
-        w2 = mean_part + cov_part
+
+        mean_part, scale_part = gaussian_wasserstein_commutative(
+            (mean, scale_or_sqrt),
+            (proj_mean, proj_scale_or_sqrt),
+            self.scale_prec
+        )
+        w2 = mean_part + scale_part
         return w2.mean() * self.trust_region_coeff
 
-    def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
-        diff = mean - old_mean
-        norm = torch.sqrt(mean_part)
-        return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm.unsqueeze(-1), mean)
-
-    def _cov_projection(self, scale_or_sqrt: torch.Tensor, old_scale_or_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
+    def _scale_projection(self, scale_or_sqrt: torch.Tensor, old_scale_or_sqrt: torch.Tensor, scale_part: torch.Tensor) -> torch.Tensor:
+        """Project scale parameters using multiplicative update."""
         if scale_or_sqrt.dim() == old_scale_or_sqrt.dim() == 2:  # Diagonal case
-            diff = scale_or_sqrt - old_scale_or_sqrt
-            norm = torch.sqrt(cov_part)
-            return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm.unsqueeze(-1), scale_or_sqrt)
+            return self._diagonal_scale_projection(scale_or_sqrt, old_scale_or_sqrt, scale_part)
         else:  # Full covariance case
-            diff = scale_or_sqrt - old_scale_or_sqrt
-            norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
-            return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm, scale_or_sqrt)
+            return self._full_cov_scale_projection(scale_or_sqrt, old_scale_or_sqrt, scale_part)
+
+    def _diagonal_scale_projection(self, scale: torch.Tensor, old_scale: torch.Tensor, scale_part: torch.Tensor) -> torch.Tensor:
+        cov_mask = scale_part > self.cov_bound
+        batch_shape = scale.shape[:-1]
+        eta = torch.ones(batch_shape, dtype=scale.dtype, device=scale.device)
+        eta = torch.where(cov_mask,
+                          torch.sqrt(scale_part / self.cov_bound) - 1.,
+                          eta)
+        eta = torch.maximum(-eta, eta)
+
+        new_scale = (scale + eta[..., None] * old_scale) / \
+            (1. + eta + 1e-16)[..., None]
+        mask_matrix = cov_mask[..., None].to(scale.dtype)
+        return torch.where(mask_matrix, new_scale, scale)
+
+    def _full_cov_scale_projection(self, scale_sqrt: torch.Tensor, old_scale_sqrt: torch.Tensor, scale_part: torch.Tensor) -> torch.Tensor:
+        cov_mask = scale_part > self.cov_bound
+        batch_shape = scale_sqrt.shape[:-2]
+        eta = torch.ones(batch_shape, dtype=scale_sqrt.dtype, device=scale_sqrt.device)
+        eta = torch.where(cov_mask,
+                          torch.sqrt(scale_part / self.cov_bound) - 1.,
+                          eta)
+        eta = torch.maximum(-eta, eta)
+
+        new_scale = (scale_sqrt + torch.einsum('...,...ij->...ij', eta, old_scale_sqrt)) / \
+            (1. + eta + 1e-16)[..., None, None]
+        mask_matrix = cov_mask[..., None, None].to(scale_sqrt.dtype)
+        return torch.where(mask_matrix, new_scale, scale_sqrt)
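
In the commutative (diagonal) case, the W2 covariance term sum(s_o**2 + s**2 - 2*s_o*s) used throughout this file is just the squared difference of the standard deviations. A one-line sanity check (illustrative values):

import torch

s = torch.rand(5, 3) + 0.1
s_o = torch.rand(5, 3) + 0.1

scale_part = torch.sum(s_o ** 2 + s ** 2 - 2 * s_o * s, dim=-1)
expected = torch.sum((s - s_o) ** 2, dim=-1)
print(torch.allclose(scale_part, expected, atol=1e-6))  # True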