New projection impls

parent d29417187f
commit 5fc4b30ea8

fancy_rl/projections/__init__.py
@@ -1,6 +1,20 @@
-try:
-    import cpp_projection
-except ModuleNotFoundError:
-    from .base_projection_layer import ITPALExceptionLayer as KLProjectionLayer
-else:
-    from .kl_projection_layer import KLProjectionLayer
+from .base_projection import BaseProjection
+from .identity_projection import IdentityProjection
+from .kl_projection import KLProjection
+from .wasserstein_projection import WassersteinProjection
+from .frob_projection import FrobeniusProjection  # module added below is frob_projection.py, not frobenius_projection.py
+
+def get_projection(projection_name: str):
+    projections = {
+        "identity_projection": IdentityProjection,
+        "kl_projection": KLProjection,
+        "wasserstein_projection": WassersteinProjection,
+        "frobenius_projection": FrobeniusProjection,
+    }
+
+    projection = projections.get(projection_name.lower())
+    if projection is None:
+        raise ValueError(f"Unknown projection: {projection_name}")
+    return projection
+
+__all__ = ["BaseProjection", "IdentityProjection", "KLProjection", "WassersteinProjection", "FrobeniusProjection", "get_projection"]
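A minimal usage sketch of the new factory. Note that importing the package pulls in kl_projection.py, which imports cpp_projection eagerly, so the compiled backend has to be installed even to look up one of the other projections:

    from fancy_rl.projections import get_projection

    ProjCls = get_projection("Wasserstein_Projection")  # lookup is case-insensitive
    proj = ProjCls(in_keys=["loc", "scale_tril"], out_keys=["loc", "scale_tril"],
                   mean_bound=0.03, cov_bound=0.001)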
fancy_rl/projections/base_projection.py (new file)
@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
import torch
from typing import Dict

class BaseProjection(ABC, torch.nn.Module):
    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01):
        super().__init__()
        self.in_keys = in_keys
        self.out_keys = out_keys
        # Every subclass below forwards these bounds via super().__init__(),
        # so they must be accepted and stored here.
        self.trust_region_coeff = trust_region_coeff
        self.mean_bound = mean_bound
        self.cov_bound = cov_bound

    @abstractmethod
    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        pass

    def forward(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return self.project(policy_params, old_policy_params)
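For reference, a throwaway subclass (hypothetical, not part of this commit) only has to implement project(); forward() delegates to it:

    class ClampProjection(BaseProjection):
        # Toy example: clamp the new mean into a box around the old mean.
        def project(self, policy_params, old_policy_params):
            loc = torch.clamp(policy_params["loc"],
                              old_policy_params["loc"] - 0.1,
                              old_policy_params["loc"] + 0.1)
            return {"loc": loc, "scale_tril": policy_params["scale_tril"]}

    clamp = ClampProjection(in_keys=["loc", "scale_tril"], out_keys=["loc", "scale_tril"])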
fancy_rl/projections/frob_projection.py (new file)
@@ -0,0 +1,67 @@
import torch
from .base_projection import BaseProjection
from typing import Dict

class FrobeniusProjection(BaseProjection):
    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
        self.scale_prec = scale_prec

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        mean, chol = policy_params["loc"], policy_params["scale_tril"]
        old_mean, old_chol = old_policy_params["loc"], old_policy_params["scale_tril"]

        cov = torch.matmul(chol, chol.transpose(-1, -2))
        old_cov = torch.matmul(old_chol, old_chol.transpose(-1, -2))

        mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))

        proj_mean = self._mean_projection(mean, old_mean, mean_part)
        proj_cov = self._cov_projection(cov, old_cov, cov_part)

        proj_chol = torch.linalg.cholesky(proj_cov)
        return {"loc": proj_mean, "scale_tril": proj_chol}

    def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
        mean, chol = policy_params["loc"], policy_params["scale_tril"]
        proj_mean, proj_chol = proj_policy_params["loc"], proj_policy_params["scale_tril"]

        cov = torch.matmul(chol, chol.transpose(-1, -2))
        proj_cov = torch.matmul(proj_chol, proj_chol.transpose(-1, -2))

        mean_diff = torch.sum(torch.square(mean - proj_mean), dim=-1)
        cov_diff = torch.sum(torch.square(cov - proj_cov), dim=(-2, -1))

        return (mean_diff + cov_diff).mean() * self.trust_region_coeff

    def _gaussian_frobenius(self, p, q):
        mean, cov = p
        old_mean, old_cov = q

        if self.scale_prec:
            prec_old = torch.inverse(old_cov)
            diff = mean - old_mean
            # Batched quadratic form d^T P d; a plain matmul on a (batch, k)
            # vector would broadcast incorrectly against (batch, k, k).
            mean_part = torch.einsum('...i,...ij,...j->...', diff, prec_old, diff)
            cov_part = torch.sum(prec_old * cov, dim=(-2, -1)) - torch.logdet(torch.matmul(prec_old, cov)) - mean.shape[-1]
        else:
            mean_part = torch.sum(torch.square(mean - old_mean), dim=-1)
            cov_part = torch.sum(torch.square(cov - old_cov), dim=(-2, -1))

        return mean_part, cov_part

    def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
        diff = mean - old_mean
        norm = torch.sqrt(mean_part)
        # The (batch,) mask and norm need a trailing dim to broadcast against (batch, k).
        return torch.where((norm > self.mean_bound).unsqueeze(-1), old_mean + diff * self.mean_bound / norm.unsqueeze(-1), mean)

    def _cov_projection(self, cov: torch.Tensor, old_cov: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
        batch_shape = cov.shape[:-2]
        cov_mask = cov_part > self.cov_bound

        eta = torch.ones(batch_shape, dtype=cov.dtype, device=cov.device)
        eta[cov_mask] = torch.sqrt(cov_part[cov_mask] / self.cov_bound) - 1.
        eta = torch.max(-eta, eta)

        # Broadcast over arbitrary batch shapes; the einsum 'i,ijk->ijk' this
        # replaces assumed exactly one batch dimension.
        new_cov = (cov + eta[..., None, None] * old_cov) / (1. + eta + 1e-16)[..., None, None]
        proj_cov = torch.where(cov_mask[..., None, None], new_cov, cov)

        return proj_cov
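A quick smoke test for the Frobenius projection (illustrative values only; shapes follow the "loc"/"scale_tril" convention used throughout this commit):

    import torch
    from fancy_rl.projections.frob_projection import FrobeniusProjection

    proj = FrobeniusProjection(in_keys=["loc", "scale_tril"], out_keys=["loc", "scale_tril"],
                               mean_bound=0.01, cov_bound=0.01)
    old = {"loc": torch.zeros(8, 3), "scale_tril": torch.eye(3).expand(8, 3, 3)}
    new = {"loc": 0.5 * torch.randn(8, 3), "scale_tril": torch.diag_embed(torch.rand(8, 3) + 0.5)}
    out = proj.project(new, old)
    print(out["loc"].shape, out["scale_tril"].shape)  # torch.Size([8, 3]) torch.Size([8, 3, 3])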
fancy_rl/projections/identity_projection.py (new file)
@@ -0,0 +1,13 @@
import torch
from .base_projection import BaseProjection
from typing import Dict

class IdentityProjection(BaseProjection):
    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01):
        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return policy_params

    def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
        return torch.tensor(0.0, device=next(iter(policy_params.values())).device)
fancy_rl/projections/kl_projection.py (new file)
@@ -0,0 +1,199 @@
import torch
import cpp_projection  # compiled C++ backend; this module requires it at import time
import numpy as np
from .base_projection import BaseProjection
from typing import Dict, Tuple, Any

MAX_EVAL = 1000

def get_numpy(tensor):
    return tensor.detach().cpu().numpy()

class KLProjection(BaseProjection):
    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, is_diag: bool = True, contextual_std: bool = True):
        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
        self.is_diag = is_diag
        self.contextual_std = contextual_std

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        mean, std = policy_params["loc"], policy_params["scale_tril"]
        old_mean, old_std = old_policy_params["loc"], old_policy_params["scale_tril"]

        mean_part, cov_part = self._gaussian_kl((mean, std), (old_mean, old_std))

        if not self.contextual_std:
            # A non-contextual std is shared across the batch, so projecting a
            # single representative entry is sufficient.
            std = std[:1]
            old_std = old_std[:1]
            cov_part = cov_part[:1]

        proj_mean = self._mean_projection(mean, old_mean, mean_part)
        proj_std = self._cov_projection(std, old_std, cov_part)

        if not self.contextual_std:
            proj_std = proj_std.expand(mean.shape[0], -1, -1)

        return {"loc": proj_mean, "scale_tril": proj_std}

    def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
        mean, std = policy_params["loc"], policy_params["scale_tril"]
        proj_mean, proj_std = proj_policy_params["loc"], proj_policy_params["scale_tril"]
        kl = sum(self._gaussian_kl((mean, std), (proj_mean, proj_std)))
        return kl.mean() * self.trust_region_coeff

    def _gaussian_kl(self, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        mean, std = p
        mean_other, std_other = q
        k = mean.shape[-1]

        maha_part = 0.5 * self._maha(mean, mean_other, std_other)

        det_term = self._log_determinant(std)
        det_term_other = self._log_determinant(std_other)

        trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(std_other, std, upper=False))
        cov_part = 0.5 * (trace_part - k + det_term_other - det_term)

        return maha_part, cov_part

    def _maha(self, x: torch.Tensor, y: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
        diff = x - y
        # torch.triangular_solve is deprecated/removed; torch.linalg.solve_triangular
        # is the supported equivalent (note the swapped argument order).
        return torch.sum(torch.square(torch.linalg.solve_triangular(std, diff.unsqueeze(-1), upper=False).squeeze(-1)), dim=-1)

    def _log_determinant(self, std: torch.Tensor) -> torch.Tensor:
        return 2 * torch.log(std.diagonal(dim1=-2, dim2=-1)).sum(-1)

    def _torch_batched_trace_square(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sum(x.pow(2), dim=(-2, -1))

    def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
        # Clamp the scaling factor at 1 so means that already satisfy the bound
        # are left unchanged instead of being pushed outward.
        factor = torch.sqrt(self.mean_bound / (mean_part + 1e-8)).clamp(max=1.0)
        return old_mean + (mean - old_mean) * factor.unsqueeze(-1)

    def _cov_projection(self, std: torch.Tensor, old_std: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
        cov = torch.matmul(std, std.transpose(-1, -2))
        old_cov = torch.matmul(old_std, old_std.transpose(-1, -2))

        if self.is_diag:
            mask = cov_part > self.cov_bound
            proj_std = torch.zeros_like(std)
            proj_std[~mask] = std[~mask]
            try:
                if mask.any():
                    proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov.diagonal(dim1=-2, dim2=-1),
                                                                         old_cov.diagonal(dim1=-2, dim2=-1),
                                                                         self.cov_bound)
                    is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
                    if is_invalid.any():
                        proj_std[is_invalid] = old_std[is_invalid]
                        mask &= ~is_invalid
                    proj_std[mask] = proj_cov[mask].sqrt().diag_embed()
            except Exception:
                proj_std = old_std
        else:
            try:
                mask = cov_part > self.cov_bound
                proj_std = torch.zeros_like(std)
                proj_std[~mask] = std[~mask]
                if mask.any():
                    proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, std.detach(), old_std, self.cov_bound)
                    is_invalid = proj_cov.mean([-2, -1]).isnan() & mask
                    if is_invalid.any():
                        proj_std[is_invalid] = old_std[is_invalid]
                        mask &= ~is_invalid
                    chol, info = torch.linalg.cholesky_ex(proj_cov[mask])
                    proj_std[mask] = chol
                    # info indexes the masked subset; scatter failures back to
                    # full-batch indices before falling back to the old cholesky.
                    failed = torch.zeros_like(mask)
                    failed[mask] = info.bool()
                    if failed.any():
                        proj_std[failed] = old_std[failed]
            except Exception:
                import logging
                logging.error('Projection failed, taking old cholesky for projection.')
                proj_std = old_std

        return proj_std


class KLProjectionGradFunctionCovOnly(torch.autograd.Function):
    projection_op = None

    @staticmethod
    def get_projection_op(batch_shape, dim, max_eval=MAX_EVAL):
        if not KLProjectionGradFunctionCovOnly.projection_op:
            KLProjectionGradFunctionCovOnly.projection_op = \
                cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=max_eval)
        return KLProjectionGradFunctionCovOnly.projection_op

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        cov, chol, old_chol, eps_cov = args

        batch_shape = cov.shape[0]
        dim = cov.shape[-1]

        cov_np = get_numpy(cov)
        chol_np = get_numpy(chol)
        old_chol_np = get_numpy(old_chol)
        # eps_cov is a plain float bound, not a tensor, so it must not go through get_numpy.
        eps = float(eps_cov) * np.ones(batch_shape)

        p_op = KLProjectionGradFunctionCovOnly.get_projection_op(batch_shape, dim)
        ctx.proj = p_op

        proj_std = p_op.forward(eps, old_chol_np, chol_np, cov_np)

        return cov.new(proj_std)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        projection_op = ctx.proj
        d_cov, = grad_outputs

        d_cov_np = get_numpy(d_cov)
        d_cov_np = np.atleast_2d(d_cov_np)

        df_stds = projection_op.backward(d_cov_np)
        df_stds = np.atleast_2d(df_stds)

        df_stds = d_cov.new(df_stds)

        return df_stds, None, None, None


class KLProjectionGradFunctionDiagCovOnly(torch.autograd.Function):
    projection_op = None

    @staticmethod
    def get_projection_op(batch_shape, dim, max_eval=MAX_EVAL):
        if not KLProjectionGradFunctionDiagCovOnly.projection_op:
            KLProjectionGradFunctionDiagCovOnly.projection_op = \
                cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=max_eval)
        return KLProjectionGradFunctionDiagCovOnly.projection_op

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        cov, old_cov, eps_cov = args

        batch_shape = cov.shape[0]
        dim = cov.shape[-1]

        cov_np = get_numpy(cov)
        old_cov_np = get_numpy(old_cov)
        eps = float(eps_cov) * np.ones(batch_shape)

        p_op = KLProjectionGradFunctionDiagCovOnly.get_projection_op(batch_shape, dim)
        ctx.proj = p_op

        proj_std = p_op.forward(eps, old_cov_np, cov_np)

        return cov.new(proj_std)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        projection_op = ctx.proj
        d_std, = grad_outputs

        d_cov_np = get_numpy(d_std)
        d_cov_np = np.atleast_2d(d_cov_np)
        df_stds = projection_op.backward(d_cov_np)
        df_stds = np.atleast_2d(df_stds)

        return d_std.new(df_stds), None, None
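The KL projection delegates the covariance step to the compiled cpp_projection extension, so unlike the other projections it cannot run without it. A sketch, assuming the extension is installed (diagonal case):

    import torch
    from fancy_rl.projections.kl_projection import KLProjection

    proj = KLProjection(in_keys=["loc", "scale_tril"], out_keys=["loc", "scale_tril"],
                        mean_bound=0.05, cov_bound=0.001, is_diag=True)
    old = {"loc": torch.zeros(4, 2), "scale_tril": torch.diag_embed(torch.ones(4, 2))}
    new = {"loc": torch.randn(4, 2), "scale_tril": torch.diag_embed(torch.rand(4, 2) + 0.5)}
    projected = proj.project(new, old)
    loss = proj.get_trust_region_loss(new, projected)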
fancy_rl/projections/wasserstein_projection.py (new file)
@@ -0,0 +1,56 @@
import torch
from .base_projection import BaseProjection
from typing import Dict, Tuple

def gaussian_wasserstein_commutative(policy, p: Tuple[torch.Tensor, torch.Tensor],
                                     q: Tuple[torch.Tensor, torch.Tensor], scale_prec=False) -> Tuple[torch.Tensor, torch.Tensor]:
    # `policy` is unused here but kept for signature compatibility with callers.
    mean, sqrt = p
    mean_other, sqrt_other = q

    mean_part = torch.sum(torch.square(mean - mean_other), dim=-1)

    cov = torch.matmul(sqrt, sqrt.transpose(-1, -2))
    cov_other = torch.matmul(sqrt_other, sqrt_other.transpose(-1, -2))

    # torch.trace only accepts 2-d input; take a batched trace via the diagonal.
    if scale_prec:
        identity = torch.eye(mean.shape[-1], dtype=sqrt.dtype, device=sqrt.device)
        sqrt_inv_other = torch.linalg.solve(sqrt_other, identity)
        c = sqrt_inv_other @ cov @ sqrt_inv_other
        cov_part = (identity + c - 2 * sqrt_inv_other @ sqrt).diagonal(dim1=-2, dim2=-1).sum(-1)
    else:
        cov_part = (cov_other + cov - 2 * sqrt_other @ sqrt).diagonal(dim1=-2, dim2=-1).sum(-1)

    return mean_part, cov_part

class WassersteinProjection(BaseProjection):
    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
        self.scale_prec = scale_prec

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
        old_mean, old_sqrt = old_policy_params["loc"], old_policy_params["scale_tril"]

        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (old_mean, old_sqrt), self.scale_prec)

        proj_mean = self._mean_projection(mean, old_mean, mean_part)
        proj_sqrt = self._cov_projection(sqrt, old_sqrt, cov_part)

        return {"loc": proj_mean, "scale_tril": proj_sqrt}

    def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
        mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
        proj_mean, proj_sqrt = proj_policy_params["loc"], proj_policy_params["scale_tril"]
        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (proj_mean, proj_sqrt), self.scale_prec)
        w2 = mean_part + cov_part
        return w2.mean() * self.trust_region_coeff

    def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
        diff = mean - old_mean
        norm = torch.norm(diff, dim=-1, keepdim=True)
        return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm, mean)

    def _cov_projection(self, sqrt: torch.Tensor, old_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
        diff = sqrt - old_sqrt
        norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
        return torch.where(norm > self.cov_bound, old_sqrt + diff * self.cov_bound / norm, sqrt)
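The Wasserstein variant needs no compiled backend, so a smoke test (illustrative values) runs as-is:

    import torch
    from fancy_rl.projections.wasserstein_projection import WassersteinProjection

    proj = WassersteinProjection(in_keys=["loc", "scale_tril"], out_keys=["loc", "scale_tril"],
                                 mean_bound=0.1, cov_bound=0.1)
    old = {"loc": torch.zeros(4, 2), "scale_tril": torch.diag_embed(torch.ones(4, 2))}
    new = {"loc": torch.randn(4, 2), "scale_tril": torch.diag_embed(torch.rand(4, 2) + 0.5)}
    out = proj.project(new, old)
    print(proj.get_trust_region_loss(new, out))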
fancy_rl/projections_legacy/__init__.py (new file)
@@ -0,0 +1,6 @@
try:
    import cpp_projection
except ModuleNotFoundError:
    from .base_projection_layer import ITPALExceptionLayer as KLProjectionLayer
else:
    from .kl_projection_layer import KLProjectionLayer
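This mirrors the guard the old fancy_rl/projections/__init__.py used (see the first hunk above): when the compiled cpp_projection backend is missing, KLProjectionLayer resolves to ITPALExceptionLayer, which, judging by its name, defers the failure until the layer is actually used rather than breaking the import.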
The remaining hunks delete the legacy implementations this commit replaces (the deleted files' paths are not preserved in this view):

@@ -1,13 +0,0 @@
from .base_projection import BaseProjection
from .kl_projection import KLProjection
from .w2_projection import W2Projection
from .frob_projection import FrobProjection
from .identity_projection import IdentityProjection

__all__ = [
    "BaseProjection",
    "KLProjection",
    "W2Projection",
    "FrobProjection",
    "IdentityProjection"
]
@@ -1,51 +0,0 @@
import torch
from torchrl.modules import TensorDictModule
from typing import List, Dict, Any

class BaseProjection(TensorDictModule):
    def __init__(
        self,
        in_keys: List[str],
        out_keys: List[str],
    ):
        super().__init__(in_keys=in_keys, out_keys=out_keys)

    def forward(self, tensordict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        mean, std = self.in_keys
        projected_mean, projected_std = self.out_keys

        old_mean = tensordict[mean]
        old_std = tensordict[std]

        new_mean = tensordict.get(projected_mean, old_mean)
        new_std = tensordict.get(projected_std, old_std)

        projected_params = self.project(
            {"mean": new_mean, "std": new_std},
            {"mean": old_mean, "std": old_std}
        )

        tensordict[projected_mean] = projected_params["mean"]
        tensordict[projected_std] = projected_params["std"]

        return tensordict

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        raise NotImplementedError("Subclasses must implement the project method")

    @classmethod
    def make(cls, projection_type: str, **kwargs: Any) -> 'BaseProjection':
        if projection_type == "kl":
            from .kl_projection import KLProjection
            return KLProjection(**kwargs)
        elif projection_type == "w2":
            from .w2_projection import W2Projection
            return W2Projection(**kwargs)
        elif projection_type == "frob":
            from .frob_projection import FrobProjection
            return FrobProjection(**kwargs)
        elif projection_type == "identity":
            from .identity_projection import IdentityProjection
            return IdentityProjection(**kwargs)
        else:
            raise ValueError(f"Unknown projection type: {projection_type}")
@@ -1,22 +0,0 @@
import torch
from .base_projection import BaseProjection
from typing import Dict

class FrobProjection(BaseProjection):
    def __init__(self, in_keys: list[str], out_keys: list[str], epsilon: float = 1e-3):
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.epsilon = epsilon

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        projected_params = {}
        for key in policy_params.keys():
            old_param = old_policy_params[key]
            new_param = policy_params[key]
            diff = new_param - old_param
            norm = torch.norm(diff)
            if norm > self.epsilon:
                projected_param = old_param + (self.epsilon / norm) * diff
            else:
                projected_param = new_param
            projected_params[key] = projected_param
        return projected_params
@@ -1,11 +0,0 @@
import torch
from .base_projection import BaseProjection
from typing import Dict

class IdentityProjection(BaseProjection):
    def __init__(self, in_keys: list[str], out_keys: list[str]):
        super().__init__(in_keys=in_keys, out_keys=out_keys)

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # The identity projection simply returns the new policy parameters without any modification
        return policy_params
@@ -1,33 +0,0 @@
import torch
from typing import Dict, List
from .base_projection import BaseProjection

class KLProjection(BaseProjection):
    def __init__(
        self,
        in_keys: List[str] = ["mean", "std"],
        out_keys: List[str] = ["projected_mean", "projected_std"],
        epsilon: float = 0.1
    ):
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.epsilon = epsilon

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        new_mean, new_std = policy_params["mean"], policy_params["std"]
        old_mean, old_std = old_policy_params["mean"], old_policy_params["std"]

        diff = new_mean - old_mean
        std_diff = new_std - old_std

        kl = 0.5 * (torch.sum(torch.square(diff / old_std), dim=-1) +
                    torch.sum(torch.square(std_diff / old_std), dim=-1) -
                    new_mean.shape[-1] +
                    torch.sum(torch.log(new_std / old_std), dim=-1))

        factor = torch.sqrt(self.epsilon / (kl + 1e-8))
        factor = torch.clamp(factor, max=1.0)

        projected_mean = old_mean + factor.unsqueeze(-1) * diff
        projected_std = old_std + factor.unsqueeze(-1) * std_diff

        return {"mean": projected_mean, "std": projected_std}
@@ -1,73 +0,0 @@
import torch
from .base_projection import BaseProjection
from typing import Dict, Tuple
from torchrl.modules import TensorDictModule
from torchrl.distributions import TanhNormal, Delta

class W2Projection(BaseProjection):
    def __init__(self,
                 in_keys: list[str],
                 out_keys: list[str],
                 scale_prec: bool = False):
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.scale_prec = scale_prec

    def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        projected_params = {}
        for key in policy_params.keys():
            if key.endswith('.loc'):
                mean = policy_params[key]
                old_mean = old_policy_params[key]
                std_key = key.replace('.loc', '.scale')
                std = policy_params[std_key]
                old_std = old_policy_params[std_key]

                projected_mean, projected_std = self._trust_region_projection(
                    mean, std, old_mean, old_std
                )

                projected_params[key] = projected_mean
                projected_params[std_key] = projected_std
            elif not key.endswith('.scale'):
                projected_params[key] = policy_params[key]

        return projected_params

    def _trust_region_projection(self, mean: torch.Tensor, std: torch.Tensor,
                                 old_mean: torch.Tensor, old_std: torch.Tensor,
                                 eps: float = 1e-3, eps_cov: float = 1e-3) -> Tuple[torch.Tensor, torch.Tensor]:
        mean_part, cov_part = self._gaussian_wasserstein_commutative(mean, std, old_mean, old_std)

        # Project mean
        mean_mask = mean_part > eps
        proj_mean = torch.where(mean_mask,
                                old_mean + (mean - old_mean) * torch.sqrt(eps / mean_part)[..., None],
                                mean)

        # Project covariance
        cov_mask = cov_part > eps_cov
        eta = torch.ones_like(cov_part)
        eta[cov_mask] = torch.sqrt(cov_part[cov_mask] / eps_cov) - 1.
        eta = torch.clamp(eta, -0.9, float('inf'))  # Avoid negative values that could lead to invalid standard deviations

        proj_std = (std + eta[..., None] * old_std) / (1. + eta[..., None] + 1e-8)

        return proj_mean, proj_std

    def _gaussian_wasserstein_commutative(self, mean: torch.Tensor, std: torch.Tensor,
                                          old_mean: torch.Tensor, old_std: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.scale_prec:
            # Mahalanobis distance for mean
            mean_part = ((mean - old_mean) ** 2 / (old_std ** 2 + 1e-8)).sum(-1)
        else:
            # Euclidean distance for mean
            mean_part = ((mean - old_mean) ** 2).sum(-1)

        # W2 objective for covariance
        cov_part = (std ** 2 + old_std ** 2 - 2 * std * old_std).sum(-1)

        return mean_part, cov_part

    @classmethod
    def make(cls, in_keys: list[str], out_keys: list[str], **kwargs) -> 'W2Projection':
        return cls(in_keys=in_keys, out_keys=out_keys, **kwargs)