Fixing issues with projections
This commit is contained in:
		
							parent
							
								
									71cb8593d9
								
							
						
					
					
						commit
						651ef1522f
					
				| @ -1,16 +1,71 @@ | |||||||
| from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||||
| import torch | import torch | ||||||
| from typing import Dict | from torch import nn | ||||||
|  | from typing import Dict, List | ||||||
| 
 | 
 | ||||||
| class BaseProjection(ABC, torch.nn.Module): | class BaseProjection(nn.Module, ABC): | ||||||
|     def __init__(self, in_keys: list[str], out_keys: list[str]): |     def __init__(self, in_keys: List[str], out_keys: List[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|  |         self._validate_in_keys(in_keys) | ||||||
|  |         self._validate_out_keys(out_keys) | ||||||
|         self.in_keys = in_keys |         self.in_keys = in_keys | ||||||
|         self.out_keys = out_keys |         self.out_keys = out_keys | ||||||
|  |         self.trust_region_coeff = trust_region_coeff | ||||||
|  |         self.mean_bound = mean_bound | ||||||
|  |         self.cov_bound = cov_bound | ||||||
|  |         self.full_cov = "scale_tril" in in_keys | ||||||
|  |         self.contextual_std = contextual_std | ||||||
|  | 
 | ||||||
|  |     def _validate_in_keys(self, keys: List[str]): | ||||||
|  |         valid_keys = {"loc", "scale", "scale_tril", "old_loc", "old_scale", "old_scale_tril"} | ||||||
|  |         if not set(keys).issubset(valid_keys): | ||||||
|  |             raise ValueError(f"Invalid in_keys: {keys}. Must be a subset of {valid_keys}") | ||||||
|  |         if "loc" not in keys or "old_loc" not in keys: | ||||||
|  |             raise ValueError("Both 'loc' and 'old_loc' must be included in in_keys") | ||||||
|  |         if ("scale" in keys) != ("old_scale" in keys) or ("scale_tril" in keys) != ("old_scale_tril" in keys): | ||||||
|  |             raise ValueError("in_keys must have matching 'scale'/'old_scale' or 'scale_tril'/'old_scale_tril'") | ||||||
|  | 
 | ||||||
|  |     def _validate_out_keys(self, keys: List[str]): | ||||||
|  |         valid_keys = {"loc", "scale", "scale_tril"} | ||||||
|  |         if not set(keys).issubset(valid_keys): | ||||||
|  |             raise ValueError(f"Invalid out_keys: {keys}. Must be a subset of {valid_keys}") | ||||||
|  |         if "loc" not in keys: | ||||||
|  |             raise ValueError("'loc' must be included in out_keys") | ||||||
|  |         if "scale" not in keys and "scale_tril" not in keys: | ||||||
|  |             raise ValueError("Either 'scale' or 'scale_tril' must be included in out_keys") | ||||||
| 
 | 
 | ||||||
|     @abstractmethod |     @abstractmethod | ||||||
|     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: | ||||||
|         pass |         pass | ||||||
| 
 | 
 | ||||||
|     def forward(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |     @abstractmethod | ||||||
|         return self.project(policy_params, old_policy_params) |     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor: | ||||||
|  |         pass | ||||||
|  | 
 | ||||||
|  |     def forward(self, tensordict): | ||||||
|  |         policy_params = {} | ||||||
|  |         old_policy_params = {} | ||||||
|  |          | ||||||
|  |         for key in self.in_keys: | ||||||
|  |             if key not in tensordict: | ||||||
|  |                 raise KeyError(f"Key '{key}' not found in tensordict. Available keys: {tensordict.keys()}") | ||||||
|  |              | ||||||
|  |             if key.startswith("old_"): | ||||||
|  |                 old_policy_params[key[4:]] = tensordict[key] | ||||||
|  |             else: | ||||||
|  |                 policy_params[key] = tensordict[key] | ||||||
|  | 
 | ||||||
|  |         projected_params = self.project(policy_params, old_policy_params) | ||||||
|  |         return projected_params | ||||||
|  | 
 | ||||||
|  |     def _calc_covariance(self, params: Dict[str, torch.Tensor]) -> torch.Tensor: | ||||||
|  |         if not self.full_cov: | ||||||
|  |             return torch.diag_embed(params["scale"].pow(2)) | ||||||
|  |         else: | ||||||
|  |             return torch.matmul(params["scale_tril"], params["scale_tril"].transpose(-1, -2)) | ||||||
|  | 
 | ||||||
|  |     def _calc_scale_or_scale_tril(self, cov: torch.Tensor) -> torch.Tensor: | ||||||
|  |         if not self.full_cov: | ||||||
|  |             return torch.sqrt(cov.diagonal(dim1=-2, dim2=-1)) | ||||||
|  |         else: | ||||||
|  |             return torch.linalg.cholesky(cov) | ||||||
| @ -1,33 +1,34 @@ | |||||||
| import torch | import torch | ||||||
| from .base_projection import BaseProjection | from .base_projection import BaseProjection | ||||||
|  | from tensordict.nn import TensorDictModule | ||||||
| from typing import Dict | from typing import Dict | ||||||
| 
 | 
 | ||||||
| class FrobeniusProjection(BaseProjection): | class FrobeniusProjection(BaseProjection): | ||||||
|     def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False): |     def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True): | ||||||
|         super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound) |         super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std) | ||||||
|         self.scale_prec = scale_prec |         self.scale_prec = scale_prec | ||||||
| 
 | 
 | ||||||
|     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: | ||||||
|         mean, chol = policy_params["loc"], policy_params["scale_tril"] |         mean = policy_params["loc"] | ||||||
|         old_mean, old_chol = old_policy_params["loc"], old_policy_params["scale_tril"] |         old_mean = old_policy_params["loc"] | ||||||
| 
 | 
 | ||||||
|         cov = torch.matmul(chol, chol.transpose(-1, -2)) |         cov = self._calc_covariance(policy_params) | ||||||
|         old_cov = torch.matmul(old_chol, old_chol.transpose(-1, -2)) |         old_cov = self._calc_covariance(old_policy_params) | ||||||
| 
 | 
 | ||||||
|         mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov)) |         mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov)) | ||||||
| 
 | 
 | ||||||
|         proj_mean = self._mean_projection(mean, old_mean, mean_part) |         proj_mean = self._mean_projection(mean, old_mean, mean_part) | ||||||
|         proj_cov = self._cov_projection(cov, old_cov, cov_part) |         proj_cov = self._cov_projection(cov, old_cov, cov_part) | ||||||
| 
 | 
 | ||||||
|         proj_chol = torch.linalg.cholesky(proj_cov) |         scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov) | ||||||
|         return {"loc": proj_mean, "scale_tril": proj_chol} |         return {"loc": proj_mean, self.out_keys[1]: scale_or_scale_tril} | ||||||
| 
 | 
 | ||||||
|     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor: |     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor: | ||||||
|         mean, chol = policy_params["loc"], policy_params["scale_tril"] |         mean = policy_params["loc"] | ||||||
|         proj_mean, proj_chol = proj_policy_params["loc"], proj_policy_params["scale_tril"] |         proj_mean = proj_policy_params["loc"] | ||||||
| 
 | 
 | ||||||
|         cov = torch.matmul(chol, chol.transpose(-1, -2)) |         cov = self._calc_covariance(policy_params) | ||||||
|         proj_cov = torch.matmul(proj_chol, proj_chol.transpose(-1, -2)) |         proj_cov = self._calc_covariance(proj_policy_params) | ||||||
| 
 | 
 | ||||||
|         mean_diff = torch.sum(torch.square(mean - proj_mean), dim=-1) |         mean_diff = torch.sum(torch.square(mean - proj_mean), dim=-1) | ||||||
|         cov_diff = torch.sum(torch.square(cov - proj_cov), dim=(-2, -1)) |         cov_diff = torch.sum(torch.square(cov - proj_cov), dim=(-2, -1)) | ||||||
|  | |||||||
| @ -3,8 +3,8 @@ from .base_projection import BaseProjection | |||||||
| from typing import Dict | from typing import Dict | ||||||
| 
 | 
 | ||||||
| class IdentityProjection(BaseProjection): | class IdentityProjection(BaseProjection): | ||||||
|     def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01): |     def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True): | ||||||
|         super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound) |         super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std) | ||||||
| 
 | 
 | ||||||
|     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: | ||||||
|         return policy_params |         return policy_params | ||||||
|  | |||||||
| @ -2,6 +2,7 @@ import torch | |||||||
| import cpp_projection | import cpp_projection | ||||||
| import numpy as np | import numpy as np | ||||||
| from .base_projection import BaseProjection | from .base_projection import BaseProjection | ||||||
|  | from tensordict.nn import TensorDictModule | ||||||
| from typing import Dict, Tuple, Any | from typing import Dict, Tuple, Any | ||||||
| 
 | 
 | ||||||
| MAX_EVAL = 1000 | MAX_EVAL = 1000 | ||||||
| @ -10,57 +11,65 @@ def get_numpy(tensor): | |||||||
|     return tensor.detach().cpu().numpy() |     return tensor.detach().cpu().numpy() | ||||||
| 
 | 
 | ||||||
| class KLProjection(BaseProjection): | class KLProjection(BaseProjection): | ||||||
|     def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, is_diag: bool = True, contextual_std: bool = True): |     def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True): | ||||||
|         super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound) |         super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std) | ||||||
|         self.is_diag = is_diag |  | ||||||
|         self.contextual_std = contextual_std |  | ||||||
| 
 | 
 | ||||||
|     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: | ||||||
|         mean, std = policy_params["loc"], policy_params["scale_tril"] |         mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]] | ||||||
|         old_mean, old_std = old_policy_params["loc"], old_policy_params["scale_tril"] |         old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params[self.in_keys[1]] | ||||||
| 
 | 
 | ||||||
|         mean_part, cov_part = self._gaussian_kl((mean, std), (old_mean, old_std)) |         mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril), (old_mean, old_scale_or_tril)) | ||||||
| 
 | 
 | ||||||
|         if not self.contextual_std: |         if not self.contextual_std: | ||||||
|             std = std[:1] |             scale_or_tril = scale_or_tril[:1] | ||||||
|             old_std = old_std[:1] |             old_scale_or_tril = old_scale_or_tril[:1] | ||||||
|             cov_part = cov_part[:1] |             cov_part = cov_part[:1] | ||||||
| 
 | 
 | ||||||
|         proj_mean = self._mean_projection(mean, old_mean, mean_part) |         proj_mean = self._mean_projection(mean, old_mean, mean_part) | ||||||
|         proj_std = self._cov_projection(std, old_std, cov_part) |         proj_scale_or_tril = self._cov_projection(scale_or_tril, old_scale_or_tril, cov_part) | ||||||
| 
 | 
 | ||||||
|         if not self.contextual_std: |         if not self.contextual_std: | ||||||
|             proj_std = proj_std.expand(mean.shape[0], -1, -1) |             proj_scale_or_tril = proj_scale_or_tril.expand(mean.shape[0], *proj_scale_or_tril.shape[1:]) | ||||||
| 
 | 
 | ||||||
|         return {"loc": proj_mean, "scale_tril": proj_std} |         return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_tril} | ||||||
| 
 | 
 | ||||||
|     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor: |     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor: | ||||||
|         mean, std = policy_params["loc"], policy_params["scale_tril"] |         mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]] | ||||||
|         proj_mean, proj_std = proj_policy_params["loc"], proj_policy_params["scale_tril"] |         proj_mean, proj_scale_or_tril = proj_policy_params["loc"], proj_policy_params[self.out_keys[1]] | ||||||
|         kl = sum(self._gaussian_kl((mean, std), (proj_mean, proj_std))) |         kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril))) | ||||||
|         return kl.mean() * self.trust_region_coeff |         return kl.mean() * self.trust_region_coeff | ||||||
| 
 | 
 | ||||||
|     def _gaussian_kl(self, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: |     def _gaussian_kl(self, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: | ||||||
|         mean, std = p |         mean, scale_or_tril = p | ||||||
|         mean_other, std_other = q |         mean_other, scale_or_tril_other = q | ||||||
|         k = mean.shape[-1] |         k = mean.shape[-1] | ||||||
| 
 | 
 | ||||||
|         maha_part = 0.5 * self._maha(mean, mean_other, std_other) |         maha_part = 0.5 * self._maha(mean, mean_other, scale_or_tril_other) | ||||||
| 
 | 
 | ||||||
|         det_term = self._log_determinant(std) |         det_term = self._log_determinant(scale_or_tril) | ||||||
|         det_term_other = self._log_determinant(std_other) |         det_term_other = self._log_determinant(scale_or_tril_other) | ||||||
|  | 
 | ||||||
|  |         if self.full_cov: | ||||||
|  |             trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(scale_or_tril_other, scale_or_tril, upper=False)) | ||||||
|  |         else: | ||||||
|  |             trace_part = torch.sum((scale_or_tril / scale_or_tril_other) ** 2, dim=-1) | ||||||
| 
 | 
 | ||||||
|         trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(std_other, std, upper=False)) |  | ||||||
|         cov_part = 0.5 * (trace_part - k + det_term_other - det_term) |         cov_part = 0.5 * (trace_part - k + det_term_other - det_term) | ||||||
| 
 | 
 | ||||||
|         return maha_part, cov_part |         return maha_part, cov_part | ||||||
| 
 | 
 | ||||||
|     def _maha(self, x: torch.Tensor, y: torch.Tensor, std: torch.Tensor) -> torch.Tensor: |     def _maha(self, x: torch.Tensor, y: torch.Tensor, scale_or_tril: torch.Tensor) -> torch.Tensor: | ||||||
|         diff = x - y |         diff = x - y | ||||||
|         return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), std, upper=False)[0].squeeze(-1)), dim=-1) |         if self.full_cov: | ||||||
|  |             return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), scale_or_tril, upper=False)[0].squeeze(-1)), dim=-1) | ||||||
|  |         else: | ||||||
|  |             return torch.sum(torch.square(diff / scale_or_tril), dim=-1) | ||||||
| 
 | 
 | ||||||
|     def _log_determinant(self, std: torch.Tensor) -> torch.Tensor: |     def _log_determinant(self, scale_or_tril: torch.Tensor) -> torch.Tensor: | ||||||
|         return 2 * torch.log(std.diagonal(dim1=-2, dim2=-1)).sum(-1) |         if self.full_cov: | ||||||
|  |             return 2 * torch.log(scale_or_tril.diagonal(dim1=-2, dim2=-1)).sum(-1) | ||||||
|  |         else: | ||||||
|  |             return 2 * torch.log(scale_or_tril).sum(-1) | ||||||
| 
 | 
 | ||||||
|     def _torch_batched_trace_square(self, x: torch.Tensor) -> torch.Tensor: |     def _torch_batched_trace_square(self, x: torch.Tensor) -> torch.Tensor: | ||||||
|         return torch.sum(x.pow(2), dim=(-2, -1)) |         return torch.sum(x.pow(2), dim=(-2, -1)) | ||||||
| @ -68,49 +77,45 @@ class KLProjection(BaseProjection): | |||||||
|     def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor: |     def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor: | ||||||
|         return old_mean + (mean - old_mean) * torch.sqrt(self.mean_bound / (mean_part + 1e-8)).unsqueeze(-1) |         return old_mean + (mean - old_mean) * torch.sqrt(self.mean_bound / (mean_part + 1e-8)).unsqueeze(-1) | ||||||
| 
 | 
 | ||||||
|     def _cov_projection(self, std: torch.Tensor, old_std: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor: |     def _cov_projection(self, scale_or_tril: torch.Tensor, old_scale_or_tril: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor: | ||||||
|         cov = torch.matmul(std, std.transpose(-1, -2)) |         if self.full_cov: | ||||||
|         old_cov = torch.matmul(old_std, old_std.transpose(-1, -2)) |             cov = torch.matmul(scale_or_tril, scale_or_tril.transpose(-1, -2)) | ||||||
| 
 |             old_cov = torch.matmul(old_scale_or_tril, old_scale_or_tril.transpose(-1, -2)) | ||||||
|         if self.is_diag: |  | ||||||
|             mask = cov_part > self.cov_bound |  | ||||||
|             proj_std = torch.zeros_like(std) |  | ||||||
|             proj_std[~mask] = std[~mask] |  | ||||||
|             try: |  | ||||||
|                 if mask.any(): |  | ||||||
|                     proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov.diagonal(dim1=-2, dim2=-1), |  | ||||||
|                                                                          old_cov.diagonal(dim1=-2, dim2=-1), |  | ||||||
|                                                                          self.cov_bound) |  | ||||||
|                     is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask |  | ||||||
|                     if is_invalid.any(): |  | ||||||
|                         proj_std[is_invalid] = old_std[is_invalid] |  | ||||||
|                         mask &= ~is_invalid |  | ||||||
|                     proj_std[mask] = proj_cov[mask].sqrt().diag_embed() |  | ||||||
|             except Exception as e: |  | ||||||
|                 proj_std = old_std |  | ||||||
|         else: |         else: | ||||||
|             try: |             cov = scale_or_tril.pow(2) | ||||||
|  |             old_cov = old_scale_or_tril.pow(2) | ||||||
|  | 
 | ||||||
|         mask = cov_part > self.cov_bound |         mask = cov_part > self.cov_bound | ||||||
|                 proj_std = torch.zeros_like(std) |         proj_scale_or_tril = torch.zeros_like(scale_or_tril) | ||||||
|                 proj_std[~mask] = std[~mask] |         proj_scale_or_tril[~mask] = scale_or_tril[~mask] | ||||||
|  | 
 | ||||||
|  |         try: | ||||||
|             if mask.any(): |             if mask.any(): | ||||||
|                     proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, std.detach(), old_std, self.cov_bound) |                 if self.full_cov: | ||||||
|  |                     proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, scale_or_tril.detach(), old_scale_or_tril, self.cov_bound) | ||||||
|                     is_invalid = proj_cov.mean([-2, -1]).isnan() & mask |                     is_invalid = proj_cov.mean([-2, -1]).isnan() & mask | ||||||
|                     if is_invalid.any(): |                     if is_invalid.any(): | ||||||
|                         proj_std[is_invalid] = old_std[is_invalid] |                         proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid] | ||||||
|                         mask &= ~is_invalid |                         mask &= ~is_invalid | ||||||
|                     proj_std[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask]) |                     proj_scale_or_tril[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask]) | ||||||
|                     failed_mask = failed_mask.bool() |                     failed_mask = failed_mask.bool() | ||||||
|                     if failed_mask.any(): |                     if failed_mask.any(): | ||||||
|                         proj_std[failed_mask] = old_std[failed_mask] |                         proj_scale_or_tril[failed_mask] = old_scale_or_tril[failed_mask] | ||||||
|  |                 else: | ||||||
|  |                     proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov, old_cov, self.cov_bound) | ||||||
|  |                     is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask | ||||||
|  |                     if is_invalid.any(): | ||||||
|  |                         proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid] | ||||||
|  |                         mask &= ~is_invalid | ||||||
|  |                     proj_scale_or_tril[mask] = proj_cov[mask].sqrt() | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             import logging |             import logging | ||||||
|                 logging.error('Projection failed, taking old cholesky for projection.') |             logging.error('Projection failed, taking old scale_or_tril for projection.') | ||||||
|                 print("Projection failed, taking old cholesky for projection.") |             print("Projection failed, taking old scale_or_tril for projection.") | ||||||
|                 proj_std = old_std |             proj_scale_or_tril = old_scale_or_tril | ||||||
|             raise e |             raise e | ||||||
| 
 | 
 | ||||||
|         return proj_std |         return proj_scale_or_tril | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class KLProjectionGradFunctionCovOnly(torch.autograd.Function): | class KLProjectionGradFunctionCovOnly(torch.autograd.Function): | ||||||
|  | |||||||
| @ -1,56 +1,86 @@ | |||||||
| import torch | import torch | ||||||
| from .base_projection import BaseProjection | from .base_projection import BaseProjection | ||||||
|  | from tensordict.nn import TensorDictModule | ||||||
| from typing import Dict, Tuple | from typing import Dict, Tuple | ||||||
| 
 | 
 | ||||||
|  | def scale_tril_to_sqrt(scale_tril: torch.Tensor) -> torch.Tensor: | ||||||
|  |     """ | ||||||
|  |     'Converts' scale_tril to scale_sqrt. | ||||||
|  |      | ||||||
|  |     For Wasserstein distance, we need the matrix square root, not the Cholesky decomposition. | ||||||
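|  |     The two factors coincide only for diagonal covariances (the matrix square root is symmetric, the Cholesky factor is lower triangular); the Cholesky factor is returned unchanged here and used as an approximation of the matrix square root. | ||||||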
|  |     """ | ||||||
|  |     return scale_tril | ||||||
|  | 
 | ||||||
| def gaussian_wasserstein_commutative(policy, p: Tuple[torch.Tensor, torch.Tensor], | def gaussian_wasserstein_commutative(policy, p: Tuple[torch.Tensor, torch.Tensor], | ||||||
|                                      q: Tuple[torch.Tensor, torch.Tensor], scale_prec=False) -> Tuple[torch.Tensor, torch.Tensor]: |                                      q: Tuple[torch.Tensor, torch.Tensor], scale_prec=False) -> Tuple[torch.Tensor, torch.Tensor]: | ||||||
|     mean, sqrt = p |     mean, scale_or_sqrt = p | ||||||
|     mean_other, sqrt_other = q |     mean_other, scale_or_sqrt_other = q | ||||||
| 
 | 
 | ||||||
|     mean_part = torch.sum(torch.square(mean - mean_other), dim=-1) |     mean_part = torch.sum(torch.square(mean - mean_other), dim=-1) | ||||||
| 
 | 
 | ||||||
|     cov = torch.matmul(sqrt, sqrt.transpose(-1, -2)) |     if scale_or_sqrt.dim() == mean.dim():  # Diagonal case | ||||||
|     cov_other = torch.matmul(sqrt_other, sqrt_other.transpose(-1, -2)) |         cov = scale_or_sqrt.pow(2) | ||||||
| 
 |         cov_other = scale_or_sqrt_other.pow(2) | ||||||
|         if scale_prec: |         if scale_prec: | ||||||
|         identity = torch.eye(mean.shape[-1], dtype=sqrt.dtype, device=sqrt.device) |             identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device) | ||||||
|         sqrt_inv_other = torch.linalg.solve(sqrt_other, identity) |             sqrt_inv_other = 1 / scale_or_sqrt_other | ||||||
|         c = sqrt_inv_other @ cov @ sqrt_inv_other |             c = sqrt_inv_other.pow(2) * cov | ||||||
|         cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ sqrt) |             cov_part = torch.sum(identity + c - 2 * sqrt_inv_other * scale_or_sqrt, dim=-1) | ||||||
|         else: |         else: | ||||||
|         cov_part = torch.trace(cov_other + cov - 2 * sqrt_other @ sqrt) |             cov_part = torch.sum(cov_other + cov - 2 * scale_or_sqrt_other * scale_or_sqrt, dim=-1) | ||||||
|  |     else:  # Full covariance case | ||||||
|  |         # Note: scale_or_sqrt is treated as the matrix square root, not Cholesky decomposition | ||||||
|  |         cov = torch.matmul(scale_or_sqrt, scale_or_sqrt.transpose(-1, -2)) | ||||||
|  |         cov_other = torch.matmul(scale_or_sqrt_other, scale_or_sqrt_other.transpose(-1, -2)) | ||||||
|  |         if scale_prec: | ||||||
|  |             identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device) | ||||||
|  |             sqrt_inv_other = torch.linalg.solve(scale_or_sqrt_other, identity) | ||||||
|  |             c = sqrt_inv_other @ cov @ sqrt_inv_other.transpose(-1, -2) | ||||||
|  |             cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ scale_or_sqrt) | ||||||
|  |         else: | ||||||
|  |             cov_part = torch.trace(cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt) | ||||||
| 
 | 
 | ||||||
|     return mean_part, cov_part |     return mean_part, cov_part | ||||||
| 
 | 
 | ||||||
| class WassersteinProjection(BaseProjection): | class WassersteinProjection(BaseProjection): | ||||||
|     def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False): |     def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True): | ||||||
|         super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound) |         super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std) | ||||||
|         self.scale_prec = scale_prec |         self.scale_prec = scale_prec | ||||||
| 
 | 
 | ||||||
|     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: | ||||||
|         mean, sqrt = policy_params["loc"], policy_params["scale_tril"] |         mean = policy_params["loc"] | ||||||
|         old_mean, old_sqrt = old_policy_params["loc"], old_policy_params["scale_tril"] |         old_mean = old_policy_params["loc"] | ||||||
|  |         scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]]) | ||||||
|  |         old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params[self.in_keys[1]]) | ||||||
| 
 | 
 | ||||||
|         mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (old_mean, old_sqrt), self.scale_prec) |         mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (old_mean, old_scale_or_sqrt), self.scale_prec) | ||||||
| 
 | 
 | ||||||
|         proj_mean = self._mean_projection(mean, old_mean, mean_part) |         proj_mean = self._mean_projection(mean, old_mean, mean_part) | ||||||
|         proj_sqrt = self._cov_projection(sqrt, old_sqrt, cov_part) |         proj_scale_or_sqrt = self._cov_projection(scale_or_sqrt, old_scale_or_sqrt, cov_part) | ||||||
| 
 | 
 | ||||||
|         return {"loc": proj_mean, "scale_tril": proj_sqrt} |         return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_sqrt} | ||||||
| 
 | 
 | ||||||
|     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor: |     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor: | ||||||
|         mean, sqrt = policy_params["loc"], policy_params["scale_tril"] |         mean = policy_params["loc"] | ||||||
|         proj_mean, proj_sqrt = proj_policy_params["loc"], proj_policy_params["scale_tril"] |         proj_mean = proj_policy_params["loc"] | ||||||
|         mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (proj_mean, proj_sqrt), self.scale_prec) |         scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]]) | ||||||
|  |         proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params[self.out_keys[1]]) | ||||||
|  |         mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (proj_mean, proj_scale_or_sqrt), self.scale_prec) | ||||||
|         w2 = mean_part + cov_part |         w2 = mean_part + cov_part | ||||||
|         return w2.mean() * self.trust_region_coeff |         return w2.mean() * self.trust_region_coeff | ||||||
| 
 | 
 | ||||||
|     def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor: |     def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor: | ||||||
|         diff = mean - old_mean |         diff = mean - old_mean | ||||||
|         norm = torch.norm(diff, dim=-1, keepdim=True) |         norm = torch.sqrt(mean_part) | ||||||
|         return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm, mean) |         return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm.unsqueeze(-1), mean) | ||||||
| 
 | 
 | ||||||
|     def _cov_projection(self, sqrt: torch.Tensor, old_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor: |     def _cov_projection(self, scale_or_sqrt: torch.Tensor, old_scale_or_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor: | ||||||
|         diff = sqrt - old_sqrt |         if scale_or_sqrt.dim() == old_scale_or_sqrt.dim() == 2:  # Diagonal case | ||||||
|  |             diff = scale_or_sqrt - old_scale_or_sqrt | ||||||
|  |             norm = torch.sqrt(cov_part) | ||||||
|  |             return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm.unsqueeze(-1), scale_or_sqrt) | ||||||
|  |         else:  # Full covariance case | ||||||
|  |             diff = scale_or_sqrt - old_scale_or_sqrt | ||||||
|             norm = torch.norm(diff, dim=(-2, -1), keepdim=True) |             norm = torch.norm(diff, dim=(-2, -1), keepdim=True) | ||||||
|         return torch.where(norm > self.cov_bound, old_sqrt + diff * self.cov_bound / norm, sqrt) |             return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm, scale_or_sqrt) | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user