Port some fixes / additions learned from implementing itpal_jax
This commit is contained in:
parent 3816adef9a
commit ecf0b72e88
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 import torch
 from torch import nn
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 class BaseProjection(nn.Module, ABC):
     def __init__(self, in_keys: List[str], out_keys: List[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
@@ -68,4 +68,25 @@ class BaseProjection(nn.Module, ABC):
         if not self.full_cov:
             return torch.sqrt(cov.diagonal(dim1=-2, dim2=-1))
         else:
             return torch.linalg.cholesky(cov)
+
+    def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
+        """Project mean based on the Mahalanobis objective and trust region.
+
+        Args:
+            mean: Current mean vectors
+            old_mean: Old mean vectors
+            mean_part: Mahalanobis/Euclidean distance between the two mean vectors
+
+        Returns:
+            Projected mean that satisfies the trust region
+        """
+        mask = mean_part > self.mean_bound
+        omega = torch.ones_like(mean_part, device=mean.device)
+        omega = torch.where(mask, torch.sqrt(mean_part / self.mean_bound) - 1., omega)
+        omega = torch.maximum(-omega, omega)[..., None]
+
+        # Use matrix operations instead of boolean indexing
+        m = (mean + omega * old_mean) / (1. + omega + 1e-16)
+        mask_matrix = mask[..., None].to(mean.dtype)
+        return mask_matrix * m + (1 - mask_matrix) * mean
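Review note: the new `_mean_projection` avoids boolean indexing in favor of a `torch.where`-based soft mask, so shapes stay static and gradients flow through both branches. A minimal standalone sketch of the same interpolation (hypothetical free function; `mean_bound` stands in for the module attribute):

```python
import torch

def mean_projection(mean, old_mean, mean_part, mean_bound=0.01):
    # omega > 0 only where the trust region is violated
    mask = mean_part > mean_bound
    omega = torch.ones_like(mean_part)
    omega = torch.where(mask, torch.sqrt(mean_part / mean_bound) - 1., omega)
    omega = torch.maximum(-omega, omega)[..., None]
    # Interpolate toward the old mean; blend via a float mask instead of indexing
    m = (mean + omega * old_mean) / (1. + omega + 1e-16)
    mask_matrix = mask[..., None].to(mean.dtype)
    return mask_matrix * m + (1 - mask_matrix) * mean

mean = torch.randn(4, 3, requires_grad=True)
old_mean = mean.detach() + 0.5
mean_part = torch.sum((mean - old_mean) ** 2, dim=-1)  # squared distance per sample
mean_projection(mean, old_mean, mean_part).sum().backward()  # gradients stay finite
```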
@@ -1,7 +1,7 @@
 import torch
 from .base_projection import BaseProjection
 from tensordict.nn import TensorDictModule
-from typing import Dict
+from typing import Dict, Tuple
 
 class FrobeniusProjection(BaseProjection):
     def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
@@ -12,16 +12,23 @@ class FrobeniusProjection(BaseProjection):
         mean = policy_params["loc"]
         old_mean = old_policy_params["loc"]
 
         # Convert to covariance representation
         cov = self._calc_covariance(policy_params)
         old_cov = self._calc_covariance(old_policy_params)
 
-        mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
+        if not self.contextual_std:
+            cov = cov[:1]
+            old_cov = old_cov[:1]
+
+        mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
         proj_mean = self._mean_projection(mean, old_mean, mean_part)
         proj_cov = self._cov_projection(cov, old_cov, cov_part)
 
-        scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov)
-        return {"loc": proj_mean, self.out_keys[1]: scale_or_scale_tril}
+        scale_or_tril = self._calc_scale_or_scale_tril(proj_cov)
+        if not self.contextual_std:
+            scale_or_tril = scale_or_tril.expand(mean.shape[0], *scale_or_tril.shape[1:])
+
+        return {"loc": proj_mean, self.out_keys[1]: scale_or_tril}
 
     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
         mean = policy_params["loc"]
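Review note: with `contextual_std` disabled, only the shared first row of the covariance is projected and then broadcast back over the batch. `expand` is a zero-copy broadcast, so the shared row is never materialized per sample; a quick sketch of that behavior:

```python
import torch

batch, dim = 8, 3
scale = torch.rand(1, dim) + 0.5   # single shared std row
expanded = scale.expand(batch, dim)

assert expanded.shape == (batch, dim)
assert expanded.data_ptr() == scale.data_ptr()  # a view, not a copy
```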
@@ -35,34 +42,48 @@ class FrobeniusProjection(BaseProjection):
 
         return (mean_diff + cov_diff).mean() * self.trust_region_coeff
 
-    def _gaussian_frobenius(self, p, q):
+    def _gaussian_frobenius(self, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         mean, cov = p
         old_mean, old_cov = q
 
         if self.scale_prec:
-            prec_old = torch.inverse(old_cov)
-            mean_part = torch.sum(torch.matmul(mean - old_mean, prec_old) * (mean - old_mean), dim=-1)
-            cov_part = torch.sum(prec_old * cov, dim=(-2, -1)) - torch.logdet(torch.matmul(prec_old, cov)) - mean.shape[-1]
+            if self.full_cov:
+                # Use triangular solve instead of inverse for stability
+                diff = mean - old_mean
+                solved = torch.triangular_solve(diff.unsqueeze(-1), old_cov, upper=False)[0].squeeze(-1)
+                mean_part = torch.sum(torch.square(solved), dim=-1)
+            else:
+                # Diagonal case - direct division is stable
+                mean_part = torch.sum(torch.square((mean - old_mean) / old_cov), dim=-1)
         else:
             mean_part = torch.sum(torch.square(mean - old_mean), dim=-1)
-            cov_part = torch.sum(torch.square(cov - old_cov), dim=(-2, -1))
+
+        # Covariance part
+        if self.full_cov:
+            diff = old_cov - cov
+            cov_part = torch.sum(torch.square(diff), dim=(-2, -1))
+        else:
+            diff = old_cov - cov
+            cov_part = torch.sum(torch.square(diff), dim=-1)
 
         return mean_part, cov_part
 
-    def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
-        diff = mean - old_mean
-        norm = torch.sqrt(mean_part)
-        return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm.unsqueeze(-1), mean)
-
     def _cov_projection(self, cov: torch.Tensor, old_cov: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
-        batch_shape = cov.shape[:-2]
+        batch_shape = cov.shape[:-2] if cov.ndim > 2 else cov.shape[:-1]
         cov_mask = cov_part > self.cov_bound
 
         eta = torch.ones(batch_shape, dtype=cov.dtype, device=cov.device)
-        eta[cov_mask] = torch.sqrt(cov_part[cov_mask] / self.cov_bound) - 1.
-        eta = torch.max(-eta, eta)
+        eta = torch.where(cov_mask, torch.sqrt(cov_part / self.cov_bound) - 1., eta)
+        eta = torch.maximum(-eta, eta)
 
-        new_cov = (cov + torch.einsum('i,ijk->ijk', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
-        proj_cov = torch.where(cov_mask[..., None, None], new_cov, cov)
+        if self.full_cov:
+            new_cov = (cov + torch.einsum('...,...ij->...ij', eta, old_cov)) / \
+                      (1. + eta + 1e-16)[..., None, None]
+            mask_matrix = cov_mask[..., None, None].to(cov.dtype)
+            proj_cov = torch.where(mask_matrix, new_cov, cov)
+        else:
+            new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
+            mask_matrix = cov_mask[..., None].to(cov.dtype)
+            proj_cov = torch.where(mask_matrix, new_cov, cov)
 
         return proj_cov
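Review note: the `eta` update is a convex combination of the new and old covariance. With `eta = sqrt(cov_part / cov_bound) - 1` one gets `proj - old = (cov - old) / (1 + eta)`, so a violating covariance lands exactly on the bound. A small numerical check of that identity (diagonal case, illustrative values):

```python
import torch

cov_bound = 0.01
cov = torch.tensor([1.5, 0.8])
old_cov = torch.tensor([1.0, 1.0])

cov_part = torch.sum((cov - old_cov) ** 2)   # 0.29, violates the bound
eta = torch.sqrt(cov_part / cov_bound) - 1.
proj = (cov + eta * old_cov) / (1. + eta)

# Squared Frobenius distance after projection equals the bound
assert torch.allclose(torch.sum((proj - old_cov) ** 2), torch.tensor(cov_bound), atol=1e-6)
```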
@@ -19,8 +19,17 @@ class KLProjection(BaseProjection):
         super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
 
     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
-        old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params[self.in_keys[1]]
+        self._validate_inputs(policy_params, old_policy_params)
+
+        mean = policy_params["loc"]
+        old_mean = old_policy_params["loc"]
+
+        if self.full_cov:
+            scale_or_tril = policy_params["scale_tril"]
+            old_scale_or_tril = old_policy_params["scale_tril"]
+        else:
+            scale_or_tril = policy_params["scale"]
+            old_scale_or_tril = old_policy_params["scale"]
 
         mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril), (old_mean, old_scale_or_tril))
 
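Review note: the explicit `scale` / `scale_tril` key split mirrors the `torch.distributions` convention these parameter dicts typically feed: `Normal` takes a per-dimension scale, `MultivariateNormal` a Cholesky factor. A sketch of the two parameterizations (assuming the usual `Independent`-wrapped diagonal case):

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

loc = torch.zeros(4, 3)

diag = Independent(Normal(loc, torch.ones(4, 3)), 1)                      # "scale"
full = MultivariateNormal(loc, scale_tril=torch.eye(3).expand(4, 3, 3))   # "scale_tril"

assert diag.event_shape == full.event_shape == torch.Size([3])
```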
@@ -35,11 +44,22 @@ class KLProjection(BaseProjection):
         if not self.contextual_std:
             proj_scale_or_tril = proj_scale_or_tril.expand(mean.shape[0], *proj_scale_or_tril.shape[1:])
 
-        return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_tril}
+        if self.full_cov:
+            return {"loc": proj_mean, "scale_tril": proj_scale_or_tril}
+        else:
+            return {"loc": proj_mean, "scale": proj_scale_or_tril}
 
     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
-        mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
-        proj_mean, proj_scale_or_tril = proj_policy_params["loc"], proj_policy_params[self.out_keys[1]]
+        mean = policy_params["loc"]
+        proj_mean = proj_policy_params["loc"]
+
+        if self.full_cov:
+            scale_or_tril = policy_params["scale_tril"]
+            proj_scale_or_tril = proj_policy_params["scale_tril"]
+        else:
+            scale_or_tril = policy_params["scale"]
+            proj_scale_or_tril = proj_policy_params["scale"]
 
         kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
         return kl.mean() * self.trust_region_coeff
 
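Review note: `_gaussian_kl` returns the mean and covariance terms separately, and their sum is the full Gaussian KL. For the diagonal case this decomposition can be cross-checked against `torch.distributions.kl_divergence`; a sketch (the 0.5 factors are folded in here, which the class may distribute differently across the two parts):

```python
import torch
from torch.distributions import Independent, Normal, kl_divergence

mean, scale = torch.randn(5, 3), torch.rand(5, 3) + 0.5
proj_mean, proj_scale = torch.randn(5, 3), torch.rand(5, 3) + 0.5

maha = torch.sum(((mean - proj_mean) / proj_scale) ** 2, dim=-1)
trace = torch.sum((scale / proj_scale) ** 2, dim=-1)
logdet = 2 * torch.sum(torch.log(proj_scale) - torch.log(scale), dim=-1)
kl_manual = 0.5 * (maha + trace + logdet - mean.shape[-1])

p = Independent(Normal(mean, scale), 1)
q = Independent(Normal(proj_mean, proj_scale), 1)
assert torch.allclose(kl_manual, kl_divergence(p, q), atol=1e-5)
```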
@@ -54,7 +74,9 @@
         det_term_other = self._log_determinant(scale_or_tril_other)
 
         if self.full_cov:
-            trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(scale_or_tril_other, scale_or_tril, upper=False))
+            trace_part = self._batched_trace_square(
+                torch.triangular_solve(scale_or_tril, scale_or_tril_other, upper=False)[0]
+            )
         else:
             trace_part = torch.sum((scale_or_tril / scale_or_tril_other) ** 2, dim=-1)
 
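Review note: this hunk swaps `torch.linalg.solve_triangular` for the older `torch.triangular_solve`. The two solve the same triangular system with swapped argument order; `torch.triangular_solve` is the deprecated API in recent PyTorch releases, so this reads like a compatibility choice. A quick equivalence check:

```python
import torch

A = torch.tril(torch.rand(3, 3)) + torch.eye(3)  # well-conditioned lower-triangular
B = torch.rand(3, 2)

x_new = torch.linalg.solve_triangular(A, B, upper=False)    # solves A X = B
x_old = torch.triangular_solve(B, A, upper=False).solution  # deprecated equivalent

assert torch.allclose(x_new, x_old)
```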
@@ -65,62 +87,60 @@
     def _maha(self, x: torch.Tensor, y: torch.Tensor, scale_or_tril: torch.Tensor) -> torch.Tensor:
         diff = x - y
         if self.full_cov:
-            return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), scale_or_tril, upper=False)[0].squeeze(-1)), dim=-1)
+            solved = torch.triangular_solve(diff.unsqueeze(-1), scale_or_tril, upper=False)[0]
+            return torch.sum(torch.square(solved.squeeze(-1)), dim=-1)
         else:
             return torch.sum(torch.square(diff / scale_or_tril), dim=-1)
 
     def _log_determinant(self, scale_or_tril: torch.Tensor) -> torch.Tensor:
         if self.full_cov:
-            return 2 * torch.log(scale_or_tril.diagonal(dim1=-2, dim2=-1)).sum(-1)
+            return 2 * torch.sum(torch.log(torch.diagonal(scale_or_tril, dim1=-2, dim2=-1)), dim=-1)
         else:
-            return 2 * torch.log(scale_or_tril).sum(-1)
+            return 2 * torch.sum(torch.log(scale_or_tril), dim=-1)
 
-    def _torch_batched_trace_square(self, x: torch.Tensor) -> torch.Tensor:
-        return torch.sum(x.pow(2), dim=(-2, -1))
-
-    def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
-        return old_mean + (mean - old_mean) * torch.sqrt(self.mean_bound / (mean_part + 1e-8)).unsqueeze(-1)
+    def _batched_trace_square(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.sum(x ** 2, dim=(-2, -1))
 
     def _cov_projection(self, scale_or_tril: torch.Tensor, old_scale_or_tril: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
         if self.full_cov:
             cov = torch.matmul(scale_or_tril, scale_or_tril.transpose(-1, -2))
             old_cov = torch.matmul(old_scale_or_tril, old_scale_or_tril.transpose(-1, -2))
         else:
-            cov = scale_or_tril.pow(2)
-            old_cov = old_scale_or_tril.pow(2)
+            cov = scale_or_tril ** 2
+            old_cov = old_scale_or_tril ** 2
 
         mask = cov_part > self.cov_bound
-        proj_scale_or_tril = torch.zeros_like(scale_or_tril)
-        proj_scale_or_tril[~mask] = scale_or_tril[~mask]
-
-        try:
-            if mask.any():
-                if self.full_cov:
-                    proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, scale_or_tril.detach(), old_scale_or_tril, self.cov_bound)
-                    is_invalid = proj_cov.mean([-2, -1]).isnan() & mask
-                    if is_invalid.any():
-                        proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
-                        mask &= ~is_invalid
-                    proj_scale_or_tril[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
-                    failed_mask = failed_mask.bool()
-                    if failed_mask.any():
-                        proj_scale_or_tril[failed_mask] = old_scale_or_tril[failed_mask]
-                else:
-                    proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov, old_cov, self.cov_bound)
-                    is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
-                    if is_invalid.any():
-                        proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
-                        mask &= ~is_invalid
-                    proj_scale_or_tril[mask] = proj_cov[mask].sqrt()
-        except Exception as e:
-            import logging
-            logging.error('Projection failed, taking old scale_or_tril for projection.')
-            print("Projection failed, taking old scale_or_tril for projection.")
-            proj_scale_or_tril = old_scale_or_tril
-            raise e
+        proj_scale_or_tril = scale_or_tril  # Start with original scale
+
+        if mask.any():
+            if self.full_cov:
+                proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound)
+                is_invalid = torch.isnan(proj_cov.mean(dim=(-2, -1)))
+                proj_scale_or_tril = torch.where(is_invalid[..., None, None], old_scale_or_tril, scale_or_tril)
+                mask = mask & ~is_invalid
+                chol = torch.linalg.cholesky(proj_cov)
+                proj_scale_or_tril = torch.where(mask[..., None, None], chol, proj_scale_or_tril)
+            else:
+                proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound)
+                is_invalid = (torch.isnan(proj_cov.mean(dim=-1)) |
+                              torch.isinf(proj_cov.mean(dim=-1)) |
+                              (proj_cov.min(dim=-1).values < 0))
+                proj_scale_or_tril = torch.where(is_invalid[..., None], old_scale_or_tril, scale_or_tril)
+                mask = mask & ~is_invalid
+                proj_scale_or_tril = torch.where(mask[..., None], torch.sqrt(proj_cov), scale_or_tril)
 
         return proj_scale_or_tril
 
+    def _validate_inputs(self, policy_params, old_policy_params):
+        if self.full_cov:
+            required_keys = ["loc", "scale_tril"]
+        else:
+            required_keys = ["loc", "scale"]
+
+        for key in required_keys:
+            if key not in policy_params or key not in old_policy_params:
+                raise KeyError(f"Missing required key '{key}' in policy parameters")
+
 
 class KLProjectionGradFunctionCovOnly(torch.autograd.Function):
     projection_op = None
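Review note: the rewritten `_cov_projection` replaces try/except plus in-place boolean indexing with a compute-everywhere, mask-and-fallback pattern: build a candidate for every batch element, flag invalid rows, and `torch.where` them back to the old value. A reduced sketch of that pattern with made-up values (the diff's `project_diag_covariance` is not needed for the illustration):

```python
import torch

proj_cov = torch.tensor([[0.5, 0.2], [float("nan"), 0.1], [0.3, -0.4]])
old_scale = torch.ones(3, 2)

candidate = torch.sqrt(proj_cov.clamp_min(0))  # NaN rows stay NaN, handled below
is_invalid = torch.isnan(proj_cov.mean(dim=-1)) | (proj_cov.min(dim=-1).values < 0)
safe = torch.where(is_invalid[..., None], old_scale, candidate)

assert not torch.isnan(safe).any()
```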
@@ -1,6 +1,5 @@
 import torch
 from .base_projection import BaseProjection
-from tensordict.nn import TensorDictModule
 from typing import Dict, Tuple
 
 def scale_tril_to_sqrt(scale_tril: torch.Tensor) -> torch.Tensor:
@@ -12,75 +11,151 @@ def scale_tril_to_sqrt(scale_tril: torch.Tensor) -> torch.Tensor:
     """
     return scale_tril
 
-def gaussian_wasserstein_commutative(policy, p: Tuple[torch.Tensor, torch.Tensor],
-                                     q: Tuple[torch.Tensor, torch.Tensor], scale_prec=False) -> Tuple[torch.Tensor, torch.Tensor]:
+def gaussian_wasserstein_commutative(p: Tuple[torch.Tensor, torch.Tensor],
+                                     q: Tuple[torch.Tensor, torch.Tensor],
+                                     scale_prec: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
     mean, scale_or_sqrt = p
     mean_other, scale_or_sqrt_other = q
 
     mean_part = torch.sum(torch.square(mean - mean_other), dim=-1)
 
     if scale_or_sqrt.dim() == mean.dim():  # Diagonal case
-        cov = scale_or_sqrt.pow(2)
-        cov_other = scale_or_sqrt_other.pow(2)
         if scale_prec:
-            identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device)
-            sqrt_inv_other = 1 / scale_or_sqrt_other
-            c = sqrt_inv_other.pow(2) * cov
-            cov_part = torch.sum(identity + c - 2 * sqrt_inv_other * scale_or_sqrt, dim=-1)
+            # More stable implementation for precision scaling
+            scale_part = torch.sum(
+                scale_or_sqrt_other**2 + scale_or_sqrt**2 -
+                2 * scale_or_sqrt_other * scale_or_sqrt,
+                dim=-1
+            )
         else:
-            cov_part = torch.sum(cov_other + cov - 2 * scale_or_sqrt_other * scale_or_sqrt, dim=-1)
+            # Standard W2 for diagonal case
+            scale_part = torch.sum(
+                scale_or_sqrt_other**2 + scale_or_sqrt**2 -
+                2 * scale_or_sqrt_other * scale_or_sqrt,
+                dim=-1
+            )
     else:  # Full covariance case
+        # Note: scale_or_sqrt is treated as the matrix square root, not Cholesky decomposition
-        cov = torch.matmul(scale_or_sqrt, scale_or_sqrt.transpose(-1, -2))
-        cov_other = torch.matmul(scale_or_sqrt_other, scale_or_sqrt_other.transpose(-1, -2))
         if scale_prec:
+            # More stable implementation using triangular solve
             identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device)
-            sqrt_inv_other = torch.linalg.solve(scale_or_sqrt_other, identity)
-            c = sqrt_inv_other @ cov @ sqrt_inv_other.transpose(-1, -2)
-            cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ scale_or_sqrt)
+            sqrt_inv_other = torch.triangular_solve(identity, scale_or_sqrt_other, upper=False)[0]
+            c = torch.matmul(sqrt_inv_other, scale_or_sqrt)
+            scale_part = torch.sum(identity**2 + c**2 - 2 * c, dim=(-2, -1))
         else:
-            cov_part = torch.trace(cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt)
+            # Standard W2 for full covariance
+            scale_part = torch.sum(
+                scale_or_sqrt_other**2 + scale_or_sqrt**2 -
+                2 * torch.matmul(scale_or_sqrt_other, scale_or_sqrt.transpose(-1, -2)),
+                dim=(-2, -1)
+            )
 
-    return mean_part, cov_part
+    return mean_part, scale_part
 
 class WassersteinProjection(BaseProjection):
-    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
-        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
+    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0,
+                 mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False,
+                 contextual_std: bool = True):
+        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff,
+                         mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
         self.scale_prec = scale_prec
 
     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         mean = policy_params["loc"]
         old_mean = old_policy_params["loc"]
-        scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
-        old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params[self.in_keys[1]])
 
-        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (old_mean, old_scale_or_sqrt), self.scale_prec)
+        # scale_tril is already Cholesky of matrix sqrt
+        scale_sqrt = policy_params[self.in_keys[1]]
+        old_scale_sqrt = old_policy_params[self.in_keys[1]]
+
+        if not self.contextual_std:
+            scale_sqrt = scale_sqrt[:1]
+            old_scale_sqrt = old_scale_sqrt[:1]
+
+        mean_part, scale_part = self._gaussian_wasserstein(
+            (mean, scale_sqrt),
+            (old_mean, old_scale_sqrt)
+        )
 
         proj_mean = self._mean_projection(mean, old_mean, mean_part)
-        proj_scale_or_sqrt = self._cov_projection(scale_or_sqrt, old_scale_or_sqrt, cov_part)
+        proj_scale_sqrt = self._scale_projection(scale_sqrt, old_scale_sqrt, scale_part)
 
-        return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_sqrt}
+        if not self.contextual_std:
+            proj_scale_sqrt = proj_scale_sqrt.expand(mean.shape[0], *proj_scale_sqrt.shape[1:])
+
+        return {"loc": proj_mean, self.out_keys[1]: proj_scale_sqrt}
 
+    def _gaussian_wasserstein(self, p: Tuple[torch.Tensor, torch.Tensor],
+                              q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+        mean, scale_sqrt = p
+        mean_other, scale_sqrt_other = q
+
+        mean_part = torch.sum(torch.square(mean - mean_other), dim=-1)
+
+        if not self.full_cov:
+            # Diagonal case is simpler
+            scale_part = torch.sum(
+                scale_sqrt_other**2 + scale_sqrt**2 -
+                2 * scale_sqrt_other * scale_sqrt,
+                dim=-1
+            )
+        else:
+            # Full covariance case uses matrix operations
+            scale_part = torch.sum(
+                scale_sqrt_other**2 + scale_sqrt**2 -
+                2 * torch.matmul(scale_sqrt_other, scale_sqrt.transpose(-1, -2)),
+                dim=(-2, -1)
+            )
+
+        return mean_part, scale_part
+
     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
         mean = policy_params["loc"]
         proj_mean = proj_policy_params["loc"]
         scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
         proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params[self.out_keys[1]])
-        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (proj_mean, proj_scale_or_sqrt), self.scale_prec)
-        w2 = mean_part + cov_part
+
+        mean_part, scale_part = gaussian_wasserstein_commutative(
+            (mean, scale_or_sqrt),
+            (proj_mean, proj_scale_or_sqrt),
+            self.scale_prec
+        )
+        w2 = mean_part + scale_part
         return w2.mean() * self.trust_region_coeff
 
     def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
         diff = mean - old_mean
         norm = torch.sqrt(mean_part)
         return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm.unsqueeze(-1), mean)
 
-    def _cov_projection(self, scale_or_sqrt: torch.Tensor, old_scale_or_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
+    def _scale_projection(self, scale_or_sqrt: torch.Tensor, old_scale_or_sqrt: torch.Tensor, scale_part: torch.Tensor) -> torch.Tensor:
+        """Project scale parameters using multiplicative update."""
         if scale_or_sqrt.dim() == old_scale_or_sqrt.dim() == 2:  # Diagonal case
-            diff = scale_or_sqrt - old_scale_or_sqrt
-            norm = torch.sqrt(cov_part)
-            return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm.unsqueeze(-1), scale_or_sqrt)
+            return self._diagonal_scale_projection(scale_or_sqrt, old_scale_or_sqrt, scale_part)
         else:  # Full covariance case
-            diff = scale_or_sqrt - old_scale_or_sqrt
-            norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
-            return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm, scale_or_sqrt)
+            return self._full_cov_scale_projection(scale_or_sqrt, old_scale_or_sqrt, scale_part)
+
+    def _diagonal_scale_projection(self, scale: torch.Tensor, old_scale: torch.Tensor, scale_part: torch.Tensor) -> torch.Tensor:
+        cov_mask = scale_part > self.cov_bound
+
+        batch_shape = scale.shape[:-1]
+        eta = torch.ones(batch_shape, dtype=scale.dtype, device=scale.device)
+        eta = torch.where(cov_mask,
+                          torch.sqrt(scale_part / self.cov_bound) - 1.,
+                          eta)
+        eta = torch.maximum(-eta, eta)
+
+        new_scale = (scale + eta[..., None] * old_scale) / \
+                    (1. + eta + 1e-16)[..., None]
+        mask_matrix = cov_mask[..., None].to(scale.dtype)
+        return torch.where(mask_matrix, new_scale, scale)
+
+    def _full_cov_scale_projection(self, scale_sqrt: torch.Tensor, old_scale_sqrt: torch.Tensor, scale_part: torch.Tensor) -> torch.Tensor:
+        cov_mask = scale_part > self.cov_bound
+
+        batch_shape = scale_sqrt.shape[:-2]
+        eta = torch.ones(batch_shape, dtype=scale_sqrt.dtype, device=scale_sqrt.device)
+        eta = torch.where(cov_mask,
+                          torch.sqrt(scale_part / self.cov_bound) - 1.,
+                          eta)
+        eta = torch.maximum(-eta, eta)
+
+        new_scale = (scale_sqrt + torch.einsum('...,...ij->...ij', eta, old_scale_sqrt)) / \
+                    (1. + eta + 1e-16)[..., None, None]
+        mask_matrix = cov_mask[..., None, None].to(scale_sqrt.dtype)
+        return torch.where(mask_matrix, new_scale, scale_sqrt)
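Review note: for diagonal Gaussians the commutative Wasserstein-2 scale term used throughout this file reduces to the squared difference of standard deviations; the expanded `s_other**2 + s**2 - 2 * s_other * s` form above computes exactly that, term by term:

```python
import torch

mu1, mu2 = torch.randn(3), torch.randn(3)
s1, s2 = torch.rand(3) + 0.5, torch.rand(3) + 0.5

mean_part = torch.sum((mu1 - mu2) ** 2)
scale_part = torch.sum(s2 ** 2 + s1 ** 2 - 2 * s2 * s1)

assert torch.allclose(scale_part, torch.sum((s1 - s2) ** 2), atol=1e-6)
w2_squared = mean_part + scale_part  # squared commutative W2 distance
```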