Fixing issues with projections
This commit is contained in:
parent
71cb8593d9
commit
651ef1522f
@ -1,16 +1,71 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import torch
|
import torch
|
||||||
from typing import Dict
|
from torch import nn
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
class BaseProjection(ABC, torch.nn.Module):
|
class BaseProjection(nn.Module, ABC):
|
||||||
def __init__(self, in_keys: list[str], out_keys: list[str]):
|
def __init__(self, in_keys: List[str], out_keys: List[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self._validate_in_keys(in_keys)
|
||||||
|
self._validate_out_keys(out_keys)
|
||||||
self.in_keys = in_keys
|
self.in_keys = in_keys
|
||||||
self.out_keys = out_keys
|
self.out_keys = out_keys
|
||||||
|
self.trust_region_coeff = trust_region_coeff
|
||||||
|
self.mean_bound = mean_bound
|
||||||
|
self.cov_bound = cov_bound
|
||||||
|
self.full_cov = "scale_tril" in in_keys
|
||||||
|
self.contextual_std = contextual_std
|
||||||
|
|
||||||
|
def _validate_in_keys(self, keys: List[str]):
|
||||||
|
valid_keys = {"loc", "scale", "scale_tril", "old_loc", "old_scale", "old_scale_tril"}
|
||||||
|
if not set(keys).issubset(valid_keys):
|
||||||
|
raise ValueError(f"Invalid in_keys: {keys}. Must be a subset of {valid_keys}")
|
||||||
|
if "loc" not in keys or "old_loc" not in keys:
|
||||||
|
raise ValueError("Both 'loc' and 'old_loc' must be included in in_keys")
|
||||||
|
if ("scale" in keys) != ("old_scale" in keys) or ("scale_tril" in keys) != ("old_scale_tril" in keys):
|
||||||
|
raise ValueError("in_keys must have matching 'scale'/'old_scale' or 'scale_tril'/'old_scale_tril'")
|
||||||
|
|
||||||
|
def _validate_out_keys(self, keys: List[str]):
|
||||||
|
valid_keys = {"loc", "scale", "scale_tril"}
|
||||||
|
if not set(keys).issubset(valid_keys):
|
||||||
|
raise ValueError(f"Invalid out_keys: {keys}. Must be a subset of {valid_keys}")
|
||||||
|
if "loc" not in keys:
|
||||||
|
raise ValueError("'loc' must be included in out_keys")
|
||||||
|
if "scale" not in keys and "scale_tril" not in keys:
|
||||||
|
raise ValueError("Either 'scale' or 'scale_tril' must be included in out_keys")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def forward(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
@abstractmethod
|
||||||
return self.project(policy_params, old_policy_params)
|
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def forward(self, tensordict):
|
||||||
|
policy_params = {}
|
||||||
|
old_policy_params = {}
|
||||||
|
|
||||||
|
for key in self.in_keys:
|
||||||
|
if key not in tensordict:
|
||||||
|
raise KeyError(f"Key '{key}' not found in tensordict. Available keys: {tensordict.keys()}")
|
||||||
|
|
||||||
|
if key.startswith("old_"):
|
||||||
|
old_policy_params[key[4:]] = tensordict[key]
|
||||||
|
else:
|
||||||
|
policy_params[key] = tensordict[key]
|
||||||
|
|
||||||
|
projected_params = self.project(policy_params, old_policy_params)
|
||||||
|
return projected_params
|
||||||
|
|
||||||
|
def _calc_covariance(self, params: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||||
|
if not self.full_cov:
|
||||||
|
return torch.diag_embed(params["scale"].pow(2))
|
||||||
|
else:
|
||||||
|
return torch.matmul(params["scale_tril"], params["scale_tril"].transpose(-1, -2))
|
||||||
|
|
||||||
|
def _calc_scale_or_scale_tril(self, cov: torch.Tensor) -> torch.Tensor:
|
||||||
|
if not self.full_cov:
|
||||||
|
return torch.sqrt(cov.diagonal(dim1=-2, dim2=-1))
|
||||||
|
else:
|
||||||
|
return torch.linalg.cholesky(cov)
|
@ -1,33 +1,34 @@
|
|||||||
import torch
|
import torch
|
||||||
from .base_projection import BaseProjection
|
from .base_projection import BaseProjection
|
||||||
|
from tensordict.nn import TensorDictModule
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
class FrobeniusProjection(BaseProjection):
|
class FrobeniusProjection(BaseProjection):
|
||||||
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
|
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
|
||||||
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
|
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
|
||||||
self.scale_prec = scale_prec
|
self.scale_prec = scale_prec
|
||||||
|
|
||||||
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||||
mean, chol = policy_params["loc"], policy_params["scale_tril"]
|
mean = policy_params["loc"]
|
||||||
old_mean, old_chol = old_policy_params["loc"], old_policy_params["scale_tril"]
|
old_mean = old_policy_params["loc"]
|
||||||
|
|
||||||
cov = torch.matmul(chol, chol.transpose(-1, -2))
|
cov = self._calc_covariance(policy_params)
|
||||||
old_cov = torch.matmul(old_chol, old_chol.transpose(-1, -2))
|
old_cov = self._calc_covariance(old_policy_params)
|
||||||
|
|
||||||
mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
|
mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
|
||||||
|
|
||||||
proj_mean = self._mean_projection(mean, old_mean, mean_part)
|
proj_mean = self._mean_projection(mean, old_mean, mean_part)
|
||||||
proj_cov = self._cov_projection(cov, old_cov, cov_part)
|
proj_cov = self._cov_projection(cov, old_cov, cov_part)
|
||||||
|
|
||||||
proj_chol = torch.linalg.cholesky(proj_cov)
|
scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov)
|
||||||
return {"loc": proj_mean, "scale_tril": proj_chol}
|
return {"loc": proj_mean, self.out_keys[1]: scale_or_scale_tril}
|
||||||
|
|
||||||
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
|
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||||
mean, chol = policy_params["loc"], policy_params["scale_tril"]
|
mean = policy_params["loc"]
|
||||||
proj_mean, proj_chol = proj_policy_params["loc"], proj_policy_params["scale_tril"]
|
proj_mean = proj_policy_params["loc"]
|
||||||
|
|
||||||
cov = torch.matmul(chol, chol.transpose(-1, -2))
|
cov = self._calc_covariance(policy_params)
|
||||||
proj_cov = torch.matmul(proj_chol, proj_chol.transpose(-1, -2))
|
proj_cov = self._calc_covariance(proj_policy_params)
|
||||||
|
|
||||||
mean_diff = torch.sum(torch.square(mean - proj_mean), dim=-1)
|
mean_diff = torch.sum(torch.square(mean - proj_mean), dim=-1)
|
||||||
cov_diff = torch.sum(torch.square(cov - proj_cov), dim=(-2, -1))
|
cov_diff = torch.sum(torch.square(cov - proj_cov), dim=(-2, -1))
|
||||||
|
@ -3,8 +3,8 @@ from .base_projection import BaseProjection
|
|||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
class IdentityProjection(BaseProjection):
|
class IdentityProjection(BaseProjection):
|
||||||
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01):
|
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
|
||||||
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
|
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
|
||||||
|
|
||||||
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||||
return policy_params
|
return policy_params
|
||||||
|
@ -2,6 +2,7 @@ import torch
|
|||||||
import cpp_projection
|
import cpp_projection
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .base_projection import BaseProjection
|
from .base_projection import BaseProjection
|
||||||
|
from tensordict.nn import TensorDictModule
|
||||||
from typing import Dict, Tuple, Any
|
from typing import Dict, Tuple, Any
|
||||||
|
|
||||||
MAX_EVAL = 1000
|
MAX_EVAL = 1000
|
||||||
@ -10,57 +11,65 @@ def get_numpy(tensor):
|
|||||||
return tensor.detach().cpu().numpy()
|
return tensor.detach().cpu().numpy()
|
||||||
|
|
||||||
class KLProjection(BaseProjection):
|
class KLProjection(BaseProjection):
|
||||||
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, is_diag: bool = True, contextual_std: bool = True):
|
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
|
||||||
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
|
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
|
||||||
self.is_diag = is_diag
|
|
||||||
self.contextual_std = contextual_std
|
|
||||||
|
|
||||||
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||||
mean, std = policy_params["loc"], policy_params["scale_tril"]
|
mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
|
||||||
old_mean, old_std = old_policy_params["loc"], old_policy_params["scale_tril"]
|
old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params[self.in_keys[1]]
|
||||||
|
|
||||||
mean_part, cov_part = self._gaussian_kl((mean, std), (old_mean, old_std))
|
mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril), (old_mean, old_scale_or_tril))
|
||||||
|
|
||||||
if not self.contextual_std:
|
if not self.contextual_std:
|
||||||
std = std[:1]
|
scale_or_tril = scale_or_tril[:1]
|
||||||
old_std = old_std[:1]
|
old_scale_or_tril = old_scale_or_tril[:1]
|
||||||
cov_part = cov_part[:1]
|
cov_part = cov_part[:1]
|
||||||
|
|
||||||
proj_mean = self._mean_projection(mean, old_mean, mean_part)
|
proj_mean = self._mean_projection(mean, old_mean, mean_part)
|
||||||
proj_std = self._cov_projection(std, old_std, cov_part)
|
proj_scale_or_tril = self._cov_projection(scale_or_tril, old_scale_or_tril, cov_part)
|
||||||
|
|
||||||
if not self.contextual_std:
|
if not self.contextual_std:
|
||||||
proj_std = proj_std.expand(mean.shape[0], -1, -1)
|
proj_scale_or_tril = proj_scale_or_tril.expand(mean.shape[0], *proj_scale_or_tril.shape[1:])
|
||||||
|
|
||||||
return {"loc": proj_mean, "scale_tril": proj_std}
|
return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_tril}
|
||||||
|
|
||||||
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
|
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||||
mean, std = policy_params["loc"], policy_params["scale_tril"]
|
mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
|
||||||
proj_mean, proj_std = proj_policy_params["loc"], proj_policy_params["scale_tril"]
|
proj_mean, proj_scale_or_tril = proj_policy_params["loc"], proj_policy_params[self.out_keys[1]]
|
||||||
kl = sum(self._gaussian_kl((mean, std), (proj_mean, proj_std)))
|
kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
|
||||||
return kl.mean() * self.trust_region_coeff
|
return kl.mean() * self.trust_region_coeff
|
||||||
|
|
||||||
def _gaussian_kl(self, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
|
def _gaussian_kl(self, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
mean, std = p
|
mean, scale_or_tril = p
|
||||||
mean_other, std_other = q
|
mean_other, scale_or_tril_other = q
|
||||||
k = mean.shape[-1]
|
k = mean.shape[-1]
|
||||||
|
|
||||||
maha_part = 0.5 * self._maha(mean, mean_other, std_other)
|
maha_part = 0.5 * self._maha(mean, mean_other, scale_or_tril_other)
|
||||||
|
|
||||||
det_term = self._log_determinant(std)
|
det_term = self._log_determinant(scale_or_tril)
|
||||||
det_term_other = self._log_determinant(std_other)
|
det_term_other = self._log_determinant(scale_or_tril_other)
|
||||||
|
|
||||||
|
if self.full_cov:
|
||||||
|
trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(scale_or_tril_other, scale_or_tril, upper=False))
|
||||||
|
else:
|
||||||
|
trace_part = torch.sum((scale_or_tril / scale_or_tril_other) ** 2, dim=-1)
|
||||||
|
|
||||||
trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(std_other, std, upper=False))
|
|
||||||
cov_part = 0.5 * (trace_part - k + det_term_other - det_term)
|
cov_part = 0.5 * (trace_part - k + det_term_other - det_term)
|
||||||
|
|
||||||
return maha_part, cov_part
|
return maha_part, cov_part
|
||||||
|
|
||||||
def _maha(self, x: torch.Tensor, y: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
|
def _maha(self, x: torch.Tensor, y: torch.Tensor, scale_or_tril: torch.Tensor) -> torch.Tensor:
|
||||||
diff = x - y
|
diff = x - y
|
||||||
return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), std, upper=False)[0].squeeze(-1)), dim=-1)
|
if self.full_cov:
|
||||||
|
return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), scale_or_tril, upper=False)[0].squeeze(-1)), dim=-1)
|
||||||
|
else:
|
||||||
|
return torch.sum(torch.square(diff / scale_or_tril), dim=-1)
|
||||||
|
|
||||||
def _log_determinant(self, std: torch.Tensor) -> torch.Tensor:
|
def _log_determinant(self, scale_or_tril: torch.Tensor) -> torch.Tensor:
|
||||||
return 2 * torch.log(std.diagonal(dim1=-2, dim2=-1)).sum(-1)
|
if self.full_cov:
|
||||||
|
return 2 * torch.log(scale_or_tril.diagonal(dim1=-2, dim2=-1)).sum(-1)
|
||||||
|
else:
|
||||||
|
return 2 * torch.log(scale_or_tril).sum(-1)
|
||||||
|
|
||||||
def _torch_batched_trace_square(self, x: torch.Tensor) -> torch.Tensor:
|
def _torch_batched_trace_square(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return torch.sum(x.pow(2), dim=(-2, -1))
|
return torch.sum(x.pow(2), dim=(-2, -1))
|
||||||
@ -68,49 +77,45 @@ class KLProjection(BaseProjection):
|
|||||||
def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
|
def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
|
||||||
return old_mean + (mean - old_mean) * torch.sqrt(self.mean_bound / (mean_part + 1e-8)).unsqueeze(-1)
|
return old_mean + (mean - old_mean) * torch.sqrt(self.mean_bound / (mean_part + 1e-8)).unsqueeze(-1)
|
||||||
|
|
||||||
def _cov_projection(self, std: torch.Tensor, old_std: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
|
def _cov_projection(self, scale_or_tril: torch.Tensor, old_scale_or_tril: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
|
||||||
cov = torch.matmul(std, std.transpose(-1, -2))
|
if self.full_cov:
|
||||||
old_cov = torch.matmul(old_std, old_std.transpose(-1, -2))
|
cov = torch.matmul(scale_or_tril, scale_or_tril.transpose(-1, -2))
|
||||||
|
old_cov = torch.matmul(old_scale_or_tril, old_scale_or_tril.transpose(-1, -2))
|
||||||
if self.is_diag:
|
|
||||||
mask = cov_part > self.cov_bound
|
|
||||||
proj_std = torch.zeros_like(std)
|
|
||||||
proj_std[~mask] = std[~mask]
|
|
||||||
try:
|
|
||||||
if mask.any():
|
|
||||||
proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov.diagonal(dim1=-2, dim2=-1),
|
|
||||||
old_cov.diagonal(dim1=-2, dim2=-1),
|
|
||||||
self.cov_bound)
|
|
||||||
is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
|
|
||||||
if is_invalid.any():
|
|
||||||
proj_std[is_invalid] = old_std[is_invalid]
|
|
||||||
mask &= ~is_invalid
|
|
||||||
proj_std[mask] = proj_cov[mask].sqrt().diag_embed()
|
|
||||||
except Exception as e:
|
|
||||||
proj_std = old_std
|
|
||||||
else:
|
else:
|
||||||
try:
|
cov = scale_or_tril.pow(2)
|
||||||
|
old_cov = old_scale_or_tril.pow(2)
|
||||||
|
|
||||||
mask = cov_part > self.cov_bound
|
mask = cov_part > self.cov_bound
|
||||||
proj_std = torch.zeros_like(std)
|
proj_scale_or_tril = torch.zeros_like(scale_or_tril)
|
||||||
proj_std[~mask] = std[~mask]
|
proj_scale_or_tril[~mask] = scale_or_tril[~mask]
|
||||||
|
|
||||||
|
try:
|
||||||
if mask.any():
|
if mask.any():
|
||||||
proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, std.detach(), old_std, self.cov_bound)
|
if self.full_cov:
|
||||||
|
proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, scale_or_tril.detach(), old_scale_or_tril, self.cov_bound)
|
||||||
is_invalid = proj_cov.mean([-2, -1]).isnan() & mask
|
is_invalid = proj_cov.mean([-2, -1]).isnan() & mask
|
||||||
if is_invalid.any():
|
if is_invalid.any():
|
||||||
proj_std[is_invalid] = old_std[is_invalid]
|
proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
|
||||||
mask &= ~is_invalid
|
mask &= ~is_invalid
|
||||||
proj_std[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
|
proj_scale_or_tril[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
|
||||||
failed_mask = failed_mask.bool()
|
failed_mask = failed_mask.bool()
|
||||||
if failed_mask.any():
|
if failed_mask.any():
|
||||||
proj_std[failed_mask] = old_std[failed_mask]
|
proj_scale_or_tril[failed_mask] = old_scale_or_tril[failed_mask]
|
||||||
|
else:
|
||||||
|
proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov, old_cov, self.cov_bound)
|
||||||
|
is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
|
||||||
|
if is_invalid.any():
|
||||||
|
proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
|
||||||
|
mask &= ~is_invalid
|
||||||
|
proj_scale_or_tril[mask] = proj_cov[mask].sqrt()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import logging
|
import logging
|
||||||
logging.error('Projection failed, taking old cholesky for projection.')
|
logging.error('Projection failed, taking old scale_or_tril for projection.')
|
||||||
print("Projection failed, taking old cholesky for projection.")
|
print("Projection failed, taking old scale_or_tril for projection.")
|
||||||
proj_std = old_std
|
proj_scale_or_tril = old_scale_or_tril
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
return proj_std
|
return proj_scale_or_tril
|
||||||
|
|
||||||
|
|
||||||
class KLProjectionGradFunctionCovOnly(torch.autograd.Function):
|
class KLProjectionGradFunctionCovOnly(torch.autograd.Function):
|
||||||
|
@ -1,56 +1,86 @@
|
|||||||
import torch
|
import torch
|
||||||
from .base_projection import BaseProjection
|
from .base_projection import BaseProjection
|
||||||
|
from tensordict.nn import TensorDictModule
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
def scale_tril_to_sqrt(scale_tril: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
'Converts' scale_tril to scale_sqrt.
|
||||||
|
|
||||||
|
For Wasserstein distance, we need the matrix square root, not the Cholesky decomposition.
|
||||||
|
But since both are lower triangular, we can treat the Cholesky decomposition as if it were the matrix square root.
|
||||||
|
"""
|
||||||
|
return scale_tril
|
||||||
|
|
||||||
def gaussian_wasserstein_commutative(policy, p: Tuple[torch.Tensor, torch.Tensor],
|
def gaussian_wasserstein_commutative(policy, p: Tuple[torch.Tensor, torch.Tensor],
|
||||||
q: Tuple[torch.Tensor, torch.Tensor], scale_prec=False) -> Tuple[torch.Tensor, torch.Tensor]:
|
q: Tuple[torch.Tensor, torch.Tensor], scale_prec=False) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
mean, sqrt = p
|
mean, scale_or_sqrt = p
|
||||||
mean_other, sqrt_other = q
|
mean_other, scale_or_sqrt_other = q
|
||||||
|
|
||||||
mean_part = torch.sum(torch.square(mean - mean_other), dim=-1)
|
mean_part = torch.sum(torch.square(mean - mean_other), dim=-1)
|
||||||
|
|
||||||
cov = torch.matmul(sqrt, sqrt.transpose(-1, -2))
|
if scale_or_sqrt.dim() == mean.dim(): # Diagonal case
|
||||||
cov_other = torch.matmul(sqrt_other, sqrt_other.transpose(-1, -2))
|
cov = scale_or_sqrt.pow(2)
|
||||||
|
cov_other = scale_or_sqrt_other.pow(2)
|
||||||
if scale_prec:
|
if scale_prec:
|
||||||
identity = torch.eye(mean.shape[-1], dtype=sqrt.dtype, device=sqrt.device)
|
identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device)
|
||||||
sqrt_inv_other = torch.linalg.solve(sqrt_other, identity)
|
sqrt_inv_other = 1 / scale_or_sqrt_other
|
||||||
c = sqrt_inv_other @ cov @ sqrt_inv_other
|
c = sqrt_inv_other.pow(2) * cov
|
||||||
cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ sqrt)
|
cov_part = torch.sum(identity + c - 2 * sqrt_inv_other * scale_or_sqrt, dim=-1)
|
||||||
else:
|
else:
|
||||||
cov_part = torch.trace(cov_other + cov - 2 * sqrt_other @ sqrt)
|
cov_part = torch.sum(cov_other + cov - 2 * scale_or_sqrt_other * scale_or_sqrt, dim=-1)
|
||||||
|
else: # Full covariance case
|
||||||
|
# Note: scale_or_sqrt is treated as the matrix square root, not Cholesky decomposition
|
||||||
|
cov = torch.matmul(scale_or_sqrt, scale_or_sqrt.transpose(-1, -2))
|
||||||
|
cov_other = torch.matmul(scale_or_sqrt_other, scale_or_sqrt_other.transpose(-1, -2))
|
||||||
|
if scale_prec:
|
||||||
|
identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device)
|
||||||
|
sqrt_inv_other = torch.linalg.solve(scale_or_sqrt_other, identity)
|
||||||
|
c = sqrt_inv_other @ cov @ sqrt_inv_other.transpose(-1, -2)
|
||||||
|
cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ scale_or_sqrt)
|
||||||
|
else:
|
||||||
|
cov_part = torch.trace(cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt)
|
||||||
|
|
||||||
return mean_part, cov_part
|
return mean_part, cov_part
|
||||||
|
|
||||||
class WassersteinProjection(BaseProjection):
|
class WassersteinProjection(BaseProjection):
|
||||||
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
|
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
|
||||||
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
|
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
|
||||||
self.scale_prec = scale_prec
|
self.scale_prec = scale_prec
|
||||||
|
|
||||||
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||||
mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
|
mean = policy_params["loc"]
|
||||||
old_mean, old_sqrt = old_policy_params["loc"], old_policy_params["scale_tril"]
|
old_mean = old_policy_params["loc"]
|
||||||
|
scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
|
||||||
|
old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params[self.in_keys[1]])
|
||||||
|
|
||||||
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (old_mean, old_sqrt), self.scale_prec)
|
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (old_mean, old_scale_or_sqrt), self.scale_prec)
|
||||||
|
|
||||||
proj_mean = self._mean_projection(mean, old_mean, mean_part)
|
proj_mean = self._mean_projection(mean, old_mean, mean_part)
|
||||||
proj_sqrt = self._cov_projection(sqrt, old_sqrt, cov_part)
|
proj_scale_or_sqrt = self._cov_projection(scale_or_sqrt, old_scale_or_sqrt, cov_part)
|
||||||
|
|
||||||
return {"loc": proj_mean, "scale_tril": proj_sqrt}
|
return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_sqrt}
|
||||||
|
|
||||||
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
|
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||||
mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
|
mean = policy_params["loc"]
|
||||||
proj_mean, proj_sqrt = proj_policy_params["loc"], proj_policy_params["scale_tril"]
|
proj_mean = proj_policy_params["loc"]
|
||||||
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (proj_mean, proj_sqrt), self.scale_prec)
|
scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
|
||||||
|
proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params[self.out_keys[1]])
|
||||||
|
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (proj_mean, proj_scale_or_sqrt), self.scale_prec)
|
||||||
w2 = mean_part + cov_part
|
w2 = mean_part + cov_part
|
||||||
return w2.mean() * self.trust_region_coeff
|
return w2.mean() * self.trust_region_coeff
|
||||||
|
|
||||||
def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
|
def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
|
||||||
diff = mean - old_mean
|
diff = mean - old_mean
|
||||||
norm = torch.norm(diff, dim=-1, keepdim=True)
|
norm = torch.sqrt(mean_part)
|
||||||
return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm, mean)
|
return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm.unsqueeze(-1), mean)
|
||||||
|
|
||||||
def _cov_projection(self, sqrt: torch.Tensor, old_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
|
def _cov_projection(self, scale_or_sqrt: torch.Tensor, old_scale_or_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
|
||||||
diff = sqrt - old_sqrt
|
if scale_or_sqrt.dim() == old_scale_or_sqrt.dim() == 2: # Diagonal case
|
||||||
|
diff = scale_or_sqrt - old_scale_or_sqrt
|
||||||
|
norm = torch.sqrt(cov_part)
|
||||||
|
return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm.unsqueeze(-1), scale_or_sqrt)
|
||||||
|
else: # Full covariance case
|
||||||
|
diff = scale_or_sqrt - old_scale_or_sqrt
|
||||||
norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
|
norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
|
||||||
return torch.where(norm > self.cov_bound, old_sqrt + diff * self.cov_bound / norm, sqrt)
|
return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm, scale_or_sqrt)
|
Loading…
Reference in New Issue
Block a user