From 78d79cf705b17167a7baa50fec09d09591c1efe9 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Sun, 2 Jun 2024 11:57:19 +0200
Subject: [PATCH] Initial code for projections

---
 fancy_rl/projections/__init__.py              |   6 +
 fancy_rl/projections/base_projection_layer.py | 191 +++++++++++++
 fancy_rl/projections/frob_projection_layer.py | 133 +++++++++
 .../projections/identity_projection_layer.py  |   5 +
 fancy_rl/projections/kl_projection_layer.py   | 256 ++++++++++++++++++
 fancy_rl/projections/w2_projection_layer.py   | 164 +++++++++++
 6 files changed, 755 insertions(+)
 create mode 100644 fancy_rl/projections/__init__.py
 create mode 100644 fancy_rl/projections/base_projection_layer.py
 create mode 100644 fancy_rl/projections/frob_projection_layer.py
 create mode 100644 fancy_rl/projections/identity_projection_layer.py
 create mode 100644 fancy_rl/projections/kl_projection_layer.py
 create mode 100644 fancy_rl/projections/w2_projection_layer.py

diff --git a/fancy_rl/projections/__init__.py b/fancy_rl/projections/__init__.py
new file mode 100644
index 0000000..dc28ef1
--- /dev/null
+++ b/fancy_rl/projections/__init__.py
@@ -0,0 +1,6 @@
+try:
+    import cpp_projection
+except ModuleNotFoundError:
+    from .base_projection_layer import ITPALExceptionLayer as KLProjectionLayer
+else:
+    from .kl_projection_layer import KLProjectionLayer
\ No newline at end of file
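Note on the fallback import: user code can always import KLProjectionLayer from fancy_rl.projections. When the ITPAL backend (cpp_projection) is missing, the alias points at ITPALExceptionLayer, which fails with an actionable message at construction time rather than at import time. A minimal sketch of the resulting behavior:

import-and-construct sketch (runnable once fancy_rl is installed):

    from fancy_rl.projections import KLProjectionLayer

    # With cpp_projection installed this builds the real KL projection;
    # without it, ITPALExceptionLayer raises and points the user to ITPAL.
    try:
        proj = KLProjectionLayer(mean_bound=0.03, cov_bound=1e-3)
    except Exception as e:
        print(f"KL projection unavailable: {e}")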
diff --git a/fancy_rl/projections/base_projection_layer.py b/fancy_rl/projections/base_projection_layer.py
new file mode 100644
index 0000000..9aae5ff
--- /dev/null
+++ b/fancy_rl/projections/base_projection_layer.py
@@ -0,0 +1,191 @@
+from typing import Any, Dict, Optional, Type, Union, Tuple, final
+
+import torch as th
+
+from fancy_rl.norm import *
+
+
+class BaseProjectionLayer(object):
+    def __init__(self,
+                 mean_bound: float = 0.03,
+                 cov_bound: float = 1e-3,
+                 trust_region_coeff: float = 1.0,
+                 scale_prec: bool = False,
+                 ):
+        self.mean_bound = mean_bound
+        self.cov_bound = cov_bound
+        self.trust_region_coeff = trust_region_coeff
+        self.scale_prec = scale_prec
+        self.mean_eq = False
+
+    def __call__(self, p, q, **kwargs):
+        return self._projection(p, q, eps=self.mean_bound, eps_cov=self.cov_bound, beta=None, **kwargs)
+
+    @final
+    def _projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, beta: th.Tensor, **kwargs):
+        return self._trust_region_projection(
+            p, q, eps, eps_cov, **kwargs)
+
+    def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):
+        """
+        Hook for implementing the specific trust region projection.
+        Args:
+            p: current distribution
+            q: old distribution
+            eps: mean trust region bound
+            eps_cov: covariance trust region bound
+            **kwargs:
+
+        Returns:
+            projected distribution
+        """
+        return p
+
+    def get_trust_region_loss(self, p, proj_p):
+        # p: predicted distribution from network output
+        # proj_p: projected distribution
+        proj_mean, proj_chol = get_mean_and_chol(proj_p)
+        p_target = new_dist_like(p, proj_mean, proj_chol)
+        kl_diff = self.trust_region_value(p, p_target)
+
+        kl_loss = kl_diff.mean()
+
+        return kl_loss * self.trust_region_coeff
+
+    def trust_region_value(self, p, q):
+        """
+        Computes the KL divergence between two Gaussian distributions p and q.
+
+        Returns:
+            full KL divergence
+        """
+        return kl_divergence(p, q)
+
+    def new_dist_like(self, orig_p, mean, cov_cholesky):
+        assert isinstance(orig_p, Distribution)
+        p = orig_p.distribution
+        if isinstance(p, th.distributions.Normal):
+            p_out = orig_p.__class__(orig_p.action_dim)
+            p_out.distribution = th.distributions.Normal(mean, cov_cholesky)
+        elif isinstance(p, th.distributions.Independent):
+            p_out = orig_p.__class__(orig_p.action_dim)
+            p_out.distribution = th.distributions.Independent(
+                th.distributions.Normal(mean, cov_cholesky), 1)
+        elif isinstance(p, th.distributions.MultivariateNormal):
+            p_out = orig_p.__class__(orig_p.action_dim)
+            p_out.distribution = th.distributions.MultivariateNormal(
+                mean, scale_tril=cov_cholesky)
+        else:
+            raise Exception('Dist-Type not implemented (of sb3 dist)')
+        return p_out
+
+
+def entropy_inequality_projection(p: th.distributions.Normal,
+                                  beta: Union[float, th.Tensor]):
+    """
+    Projects std to satisfy an entropy INEQUALITY constraint.
+    Args:
+        p: current distribution
+        beta: target entropy for EACH std or general bound for all stds
+
+    Returns:
+        projected std that satisfies the entropy bound
+    """
+    mean, std = p.mean, p.stddev
+    k = std.shape[-1]
+    batch_shape = std.shape[:-2]
+
+    ent = p.entropy()
+    mask = ent < beta
+
+    # if nothing has to be projected skip computation
+    if (~mask).all():
+        return p
+
+    alpha = th.ones(batch_shape, dtype=std.dtype, device=std.device)
+    alpha[mask] = th.exp((beta[mask] - ent[mask]) / k)
+
+    proj_std = th.einsum('ijk,i->ijk', std, alpha)
+    new_mean, new_std = mean, th.where(mask[..., None, None], proj_std, std)
+    return th.distributions.Normal(new_mean, new_std)
+
+
+def entropy_equality_projection(p: th.distributions.Normal,
+                                beta: Union[float, th.Tensor]):
+    """
+    Projects std to satisfy an entropy EQUALITY constraint.
+    Args:
+        p: current distribution
+        beta: target entropy for EACH std or general bound for all stds
+
+    Returns:
+        projected std that satisfies the entropy bound
+    """
+    mean, std = p.mean, p.stddev
+    k = std.shape[-1]
+
+    ent = p.entropy()
+    alpha = th.exp((beta - ent) / k)
+    proj_std = th.einsum('ijk,i->ijk', std, alpha)
+    new_mean, new_std = mean, proj_std
+    return th.distributions.Normal(new_mean, new_std)
+
+
+def mean_projection(mean: th.Tensor, old_mean: th.Tensor, maha: th.Tensor, eps: th.Tensor):
+    """
+    Projects the mean based on the Mahalanobis objective and trust region.
+    Args:
+        mean: current mean vectors
+        old_mean: old mean vectors
+        maha: Mahalanobis distance between the two mean vectors
+        eps: trust region bound
+
+    Returns:
+        projected mean that satisfies the trust region
+    """
+    batch_shape = mean.shape[:-1]
+    mask = maha > eps
+
+    ################################################################################################################
+    # mean projection maha
+
+    # if nothing has to be projected skip computation
+    if mask.any():
+        omega = th.ones(batch_shape, dtype=mean.dtype, device=mean.device)
+        omega[mask] = th.sqrt(maha[mask] / eps) - 1.
+        omega = th.max(-omega, omega)[..., None]
+
+        m = (mean + omega * old_mean) / (1 + omega + 1e-16)
+        proj_mean = th.where(mask[..., None], m, mean)
+    else:
+        proj_mean = mean
+
+    return proj_mean
+
+
+def mean_equality_projection(mean: th.Tensor, old_mean: th.Tensor, maha: th.Tensor, eps: th.Tensor):
+    """
+    Projects the mean based on the Mahalanobis objective and trust region for an EQUALITY constraint.
+    Args:
+        mean: current mean vectors
+        old_mean: old mean vectors
+        maha: Mahalanobis distance between the two mean vectors
+        eps: trust region bound
+
+    Returns:
+        projected mean that satisfies the trust region
+    """
+    maha[maha == 0] += 1e-16
+    omega = th.sqrt(maha / eps) - 1.
+    omega = omega[..., None]
+
+    proj_mean = (mean + omega * old_mean) / (1 + omega + 1e-16)
+
+    return proj_mean
+
+
+class ITPALExceptionLayer(BaseProjectionLayer):
+    def __init__(self,
+                 *args, **kwargs
+                 ):
+        raise Exception('To be able to use KL projections, ITPAL must be installed: https://github.com/ALRhub/ITPAL.')
\ No newline at end of file
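The closed-form mean step in mean_projection above interpolates back toward the old mean with weight omega = sqrt(maha / eps) - 1 whenever the Mahalanobis distance exceeds the bound, which places the result exactly on the trust-region boundary. A self-contained toy check (identity covariance, so the Mahalanobis distance reduces to the squared Euclidean distance):

    import torch as th

    # Toy check of the mean projection formula above (identity covariance,
    # so maha is just the squared Euclidean distance).
    mean = th.tensor([[1.0, 0.0]])
    old_mean = th.tensor([[0.0, 0.0]])
    maha = ((mean - old_mean) ** 2).sum(-1)          # -> tensor([1.0])
    eps = th.tensor(0.25)

    omega = (th.sqrt(maha / eps) - 1.0)[..., None]   # -> 1.0
    proj_mean = (mean + omega * old_mean) / (1 + omega + 1e-16)
    # proj_mean moved halfway back toward old_mean; its maha now equals eps:
    assert th.isclose(((proj_mean - old_mean) ** 2).sum(), eps)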
diff --git a/fancy_rl/projections/frob_projection_layer.py b/fancy_rl/projections/frob_projection_layer.py
new file mode 100644
index 0000000..2309ab2
--- /dev/null
+++ b/fancy_rl/projections/frob_projection_layer.py
@@ -0,0 +1,133 @@
+import torch as th
+from typing import Tuple
+
+from .base_projection_layer import BaseProjectionLayer, mean_projection
+
+from ..misc.norm import mahalanobis, frob_sq
+from ..misc.distTools import get_mean_and_chol, get_cov, new_dist_like, has_diag_cov
+
+
+class FrobeniusProjectionLayer(BaseProjectionLayer):
+
+    def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):
+        """
+        Stolen from Fabian's Code (Public Version)
+
+        Runs the Frobenius projection layer and constructs the Cholesky factor of the covariance.
+
+        Args:
+            p: current distribution
+            q: old distribution
+            eps: (modified) bound for the mean part
+            eps_cov: (modified) bound for the cov part
+            **kwargs:
+        Returns:
+            projected distribution
+        """
+        mean, chol = get_mean_and_chol(p, expand=True)
+        old_mean, old_chol = get_mean_and_chol(q, expand=True)
+        batch_shape = mean.shape[:-1]
+
+        ################################################################################################################
+        # precompute mean and cov part of frob projection, which are used for the projection.
+        mean_part, cov_part, cov, cov_old = gaussian_frobenius(
+            p, q, self.scale_prec, True)
+
+        ################################################################################################################
+        # mean projection maha/euclidean
+
+        proj_mean = mean_projection(mean, old_mean, mean_part, eps)
+
+        ################################################################################################################
+        # cov projection frobenius
+
+        cov_mask = cov_part > eps_cov
+
+        if cov_mask.any():
+            eta = th.ones(batch_shape, dtype=chol.dtype, device=chol.device)
+            eta[cov_mask] = th.sqrt(cov_part[cov_mask] / eps_cov) - 1.
+            eta = th.max(-eta, eta)
+
+            new_cov = (cov + th.einsum('i,ijk->ijk', eta, cov_old)
+                       ) / (1. + eta + 1e-16)[..., None, None]
+            proj_chol = th.where(
+                cov_mask[..., None, None], th.linalg.cholesky(new_cov), chol)
+        else:
+            proj_chol = chol
+
+        if has_diag_cov(p):
+            proj_chol = th.diagonal(proj_chol, dim1=-2, dim2=-1)
+
+        proj_p = new_dist_like(p, proj_mean, proj_chol)
+        return proj_p
+
+    def trust_region_value(self, p, q):
+        """
+        Stolen from Fabian's Code (Public Version)
+
+        Computes the Frobenius metric between two Gaussian distributions p and q.
+        Args:
+            p: current distribution
+            q: old distribution
+        Returns:
+            mean and covariance part of the Frobenius metric
+        """
+        return gaussian_frobenius(p, q, self.scale_prec)
+
+    def get_trust_region_loss(self, p, proj_p):
+        """
+        Stolen from Fabian's Code (Public Version)
+        """
+        mean_diff, _ = self.trust_region_value(p, proj_p)
+        if False and policy.contextual_std:  # disabled, never evaluated
+            # Compute MSE here, because we found the Frobenius norm tends to generate values that explode for the cov
+            p_mean, proj_p_mean = p.mean, proj_p.mean
+            cov_diff = (p_mean - proj_p_mean).pow(2).sum([-1, -2])
+            delta_loss = (mean_diff + cov_diff).mean()
+        else:
+            delta_loss = mean_diff.mean()
+
+        return delta_loss * self.trust_region_coeff
+
+
+def gaussian_frobenius(p, q, scale_prec: bool = False, return_cov: bool = False):
+    """
+    Stolen from Fabian's Code (Public Version)
+
+    Compute (mu_p - mu_q) (L_q L_q^T)^-1 (mu_p - mu_q)^T + |L_p L_p^T - L_q L_q^T|_F^2 with p, q ~ N(mu, L L^T).
+    Args:
+        p: mean and chol of gaussian p
+        q: mean and chol of gaussian q
+        return_cov: return cov matrices for further computations
+        scale_prec: scale objective with precision matrix
+    Returns: mahalanobis distance, squared frobenius norm
+    """
+    mean, chol = get_mean_and_chol(p)
+    mean_other, chol_other = get_mean_and_chol(q)
+
+    if scale_prec:
+        # maha objective for mean
+        mean_part = mahalanobis(mean, mean_other, chol_other)
+    else:
+        # euclidean distance for mean
+        # mean_part = th.norm(mean_other - mean, ord=2, axis=1) ** 2
+        mean_part = ((mean_other - mean) ** 2).sum(1)
+
+    # frob objective for cov
+    cov = get_cov(p)
+    cov_other = get_cov(q)
+    diff = cov_other - cov
+    # Matrix is real symmetric PSD, therefore |A|^2_F = tr{A @ A^T} = tr{A @ A}
+    # cov_part = torch_batched_trace(diff @ diff)
+    cov_part = frob_sq(diff, is_spd=True)
+
+    if return_cov:
+        return mean_part, cov_part, cov, cov_other
+
+    return mean_part, cov_part
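The covariance step of the Frobenius projection interpolates the covariance matrices themselves (not their Cholesky factors) toward the old covariance with weight eta, and only then re-factorizes. A self-contained sketch of that interpolation on a single pair of SPD matrices (toy values; frob_sq is replaced here by a direct squared Frobenius norm):

    import torch as th

    # Toy version of the covariance interpolation used above (batch of 1).
    cov = th.diag(th.tensor([4.0, 1.0]))     # current covariance
    cov_old = th.eye(2)                      # old covariance
    eps_cov = th.tensor(0.5)

    cov_part = ((cov - cov_old) ** 2).sum()  # squared Frobenius norm = 9.0
    if cov_part > eps_cov:
        eta = th.sqrt(cov_part / eps_cov) - 1.0
        new_cov = (cov + eta * cov_old) / (1.0 + eta)
        chol = th.linalg.cholesky(new_cov)   # re-factorize after interpolating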
diff --git a/fancy_rl/projections/identity_projection_layer.py b/fancy_rl/projections/identity_projection_layer.py
new file mode 100644
index 0000000..62216db
--- /dev/null
+++ b/fancy_rl/projections/identity_projection_layer.py
@@ -0,0 +1,5 @@
+from .base_projection_layer import BaseProjectionLayer
+
+class IdentityProjectionLayer(BaseProjectionLayer):
+    def project_from_rollouts(self, dist, rollout_data, **kwargs):
+        return dist, dist
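IdentityProjectionLayer is the no-op drop-in: the "projected" distribution is the prediction itself, so the base class's get_trust_region_loss degenerates to a self-KL of zero and training behaves like plain PPO. Note that project_from_rollouts is assumed to be the hook the trainer calls; it is not defined on BaseProjectionLayer in this patch. A sketch:

    from fancy_rl.projections.identity_projection_layer import IdentityProjectionLayer

    # Same training loop as with a real projection, but the trust-region
    # loss contributes nothing because proj_p == p.
    proj_layer = IdentityProjectionLayer(trust_region_coeff=1.0)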
diff --git a/fancy_rl/projections/kl_projection_layer.py b/fancy_rl/projections/kl_projection_layer.py
new file mode 100644
index 0000000..cc38126
--- /dev/null
+++ b/fancy_rl/projections/kl_projection_layer.py
@@ -0,0 +1,256 @@
+from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_cov, is_contextual, new_dist_like, has_diag_cov
+from .base_projection_layer import BaseProjectionLayer, mean_projection, mean_equality_projection
+
+import cpp_projection
+import numpy as np
+import torch as th
+from typing import Tuple, Any
+
+from ..misc.norm import mahalanobis
+
+MAX_EVAL = 1000
+
+
+class KLProjectionLayer(BaseProjectionLayer):
+    """
+    Stolen from Fabian's Code (Private Version)
+    """
+
+    def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):
+        """
+        Stolen from Fabian's Code (Private Version)
+
+        Runs the KL projection layer and constructs the sqrt of the covariance.
+        Args:
+            p: current distribution
+            q: old distribution
+            eps: (modified) KL bound for the mean part
+            eps_cov: (modified) KL bound for the cov part
+            **kwargs:
+
+        Returns:
+            projected distribution
+        """
+        mean, chol = get_mean_and_chol(p, expand=True)
+        old_mean, old_chol = get_mean_and_chol(q, expand=True)
+
+        ################################################################################################################
+        # project mean with closed form
+        # orig code: mean_part, _ = gaussian_kl(policy, p, q)
+        # But the mean_part is just the mahalanobis dist:
+        mean_part = mahalanobis(mean, old_mean, old_chol)
+        if self.mean_eq:
+            proj_mean = mean_equality_projection(
+                mean, old_mean, mean_part, eps)
+        else:
+            proj_mean = mean_projection(mean, old_mean, mean_part, eps)
+
+        if has_diag_cov(p):
+            cov_diag = get_diag_cov_vec(p)
+            old_cov_diag = get_diag_cov_vec(q)
+            proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov_diag,
+                                                                 old_cov_diag,
+                                                                 eps_cov)
+            proj_chol = proj_cov.sqrt()  # .diag_embed()
+        else:
+            cov = get_cov(p)
+            old_cov = get_cov(q)
+            proj_cov = KLProjectionGradFunctionCovOnly.apply(
+                cov, old_cov, chol, old_chol, eps_cov)
+            proj_chol = th.linalg.cholesky(proj_cov)
+        proj_p = new_dist_like(p, proj_mean, proj_chol)
+        return proj_p
+
+
+class KLProjectionGradFunctionCovOnly(th.autograd.Function):
+    projection_op = None
+
+    @staticmethod
+    def get_projection_op(batch_shape, dim, max_eval=MAX_EVAL):
+        if not KLProjectionGradFunctionCovOnly.projection_op:
+            KLProjectionGradFunctionCovOnly.projection_op = \
+                cpp_projection.BatchedCovOnlyProjection(
+                    batch_shape, dim, max_eval=max_eval)
+        return KLProjectionGradFunctionCovOnly.projection_op
+
+    @staticmethod
+    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
+        # std, old_std, eps_cov = args
+        cov, old_cov, chol, old_chol, eps_cov = args
+
+        batch_shape = chol.shape[0]
+        dim = chol.shape[-1]
+
+        cov_np = cov.cpu().detach().numpy()
+        old_cov_np = old_cov.cpu().detach().numpy()
+        chol_np = chol.cpu().detach().numpy()
+        old_chol_np = old_chol.cpu().detach().numpy()
+        # eps = eps_cov.cpu().detach().numpy().astype(old_std_np.dtype) * \
+        eps = eps_cov * \
+            np.ones(batch_shape, dtype=old_chol_np.dtype)
+
+        p_op = KLProjectionGradFunctionCovOnly.get_projection_op(
+            batch_shape, dim)
+        ctx.proj = p_op
+
+        proj_cov = p_op.forward(eps, old_chol_np, chol_np, cov_np)
+
+        return th.Tensor(proj_cov)
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Any) -> Any:
+        projection_op = ctx.proj
+        d_std, = grad_outputs
+
+        d_std_np = d_std.cpu().detach().numpy()
+        d_std_np = np.atleast_2d(d_std_np)
+        df_stds = projection_op.backward(d_std_np)
+        df_stds = np.atleast_2d(df_stds)
+
+        return d_std.new(df_stds), None, None, None, None
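KLProjectionGradFunctionCovOnly follows the usual recipe for wrapping an external solver in autograd: detach inputs to NumPy in forward, stash the solver handle on ctx, and return one gradient (or None) per forward argument from backward. A minimal self-contained sketch of the same pattern, with a toy doubling op standing in for the cpp_projection solver:

    import numpy as np
    import torch as th

    class NumpyBridge(th.autograd.Function):
        """Toy stand-in for the cpp_projection bridge: y = 2 * x via NumPy."""

        @staticmethod
        def forward(ctx, x):
            x_np = x.cpu().detach().numpy()
            y_np = 2.0 * x_np          # external (non-autograd) computation
            return x.new(y_np)         # back to a tensor of x's dtype/device

        @staticmethod
        def backward(ctx, grad_out):
            # One return value per forward input; d(2x)/dx = 2.
            return 2.0 * grad_out

    x = th.ones(3, requires_grad=True)
    y = NumpyBridge.apply(x).sum()
    y.backward()
    assert th.allclose(x.grad, th.full((3,), 2.0))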
class KLProjectionGradFunctionDiagCovOnly(th.autograd.Function):
+    projection_op = None
+
+    @staticmethod
+    def get_projection_op(batch_shape, dim: int, max_eval: int = MAX_EVAL):
+        if not KLProjectionGradFunctionDiagCovOnly.projection_op:
+            KLProjectionGradFunctionDiagCovOnly.projection_op = \
+                cpp_projection.BatchedDiagCovOnlyProjection(
+                    batch_shape, dim, max_eval=max_eval)
+        return KLProjectionGradFunctionDiagCovOnly.projection_op
+
+    @staticmethod
+    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
+        cov, old_cov, eps_cov = args
+
+        batch_shape = cov.shape[0]
+        dim = cov.shape[-1]
+
+        std_np = cov.to('cpu').detach().numpy()
+        old_std_np = old_cov.to('cpu').detach().numpy()
+        # eps = eps_cov.to('cpu').detach().numpy().astype(old_std_np.dtype) * np.ones(batch_shape, dtype=old_std_np.dtype)
+        eps = eps_cov * np.ones(batch_shape, dtype=old_std_np.dtype)
+
+        p_op = KLProjectionGradFunctionDiagCovOnly.get_projection_op(
+            batch_shape, dim)
+        ctx.proj = p_op
+
+        try:
+            proj_std = p_op.forward(eps, old_std_np, std_np)
+        except Exception:
+            # fall back to the unprojected values if the solver fails
+            proj_std = std_np
+
+        return cov.new(proj_std)
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Any) -> Any:
+        projection_op = ctx.proj
+        d_std, = grad_outputs
+
+        d_std_np = d_std.to('cpu').detach().numpy()
+        d_std_np = np.atleast_2d(d_std_np)
+        df_stds = projection_op.backward(d_std_np)
+        df_stds = np.atleast_2d(df_stds)
+
+        return d_std.new(df_stds), None, None
+
+
+class KLProjectionGradFunctionDiagSplit(th.autograd.Function):
+    projection_op = None
+
+    @staticmethod
+    def get_projection_op(batch_shape, dim: int, max_eval: int = MAX_EVAL):
+        if not KLProjectionGradFunctionDiagSplit.projection_op:
+            KLProjectionGradFunctionDiagSplit.projection_op = \
+                cpp_projection.BatchedSplitDiagMoreProjection(
+                    batch_shape, dim, max_eval=max_eval)
+        return KLProjectionGradFunctionDiagSplit.projection_op
+
+    @staticmethod
+    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
+        mean, cov, old_mean, old_cov, eps_mu, eps_sigma = args
+
+        batch_shape, dim = mean.shape
+
+        mean_np = mean.detach().numpy()
+        cov_np = cov.detach().numpy()
+        old_mean = old_mean.detach().numpy()
+        old_cov = old_cov.detach().numpy()
+        eps_mu = eps_mu * np.ones(batch_shape)
+        eps_sigma = eps_sigma * np.ones(batch_shape)
+
+        # p_op = cpp_projection.BatchedSplitDiagMoreProjection(batch_shape, dim, max_eval=100)
+        p_op = KLProjectionGradFunctionDiagSplit.get_projection_op(
+            batch_shape, dim)
+
+        try:
+            proj_mean, proj_cov = p_op.forward(
+                eps_mu, eps_sigma, old_mean, old_cov, mean_np, cov_np)
+        except Exception:
+            # try a second time
+            proj_mean, proj_cov = p_op.forward(
+                eps_mu, eps_sigma, old_mean, old_cov, mean_np, cov_np)
+        ctx.proj = p_op
+
+        return mean.new(proj_mean), cov.new(proj_cov)
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Any) -> Any:
+        p_op = ctx.proj
+        d_means, d_std = grad_outputs
+
+        d_std_np = d_std.detach().numpy()
+        d_std_np = np.atleast_2d(d_std_np)
+        d_mean_np = d_means.detach().numpy()
+        dtarget_means, dtarget_covs = p_op.backward(d_mean_np, d_std_np)
+        dtarget_covs = np.atleast_2d(dtarget_covs)
+
+        return d_means.new(dtarget_means), d_std.new(dtarget_covs), None, None, None, None
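For Gaussians, the KL divergence splits into a mean term (a Mahalanobis distance under the old covariance, handled in closed form above) and a covariance-only term, which is what the solver classes bound by eps_cov. A quick sketch of that decomposition for diagonal Gaussians (pure PyTorch, no cpp_projection required):

    import torch as th

    # KL(p || q) for diagonal Gaussians splits into a mean part (Mahalanobis
    # distance under q's covariance) and a covariance-only part.
    mean, var = th.tensor([0.5, 0.0]), th.tensor([2.0, 1.0])
    old_mean, old_var = th.zeros(2), th.ones(2)

    mean_part = ((mean - old_mean) ** 2 / old_var).sum()
    cov_part = 0.5 * ((var / old_var).sum() - len(var)
                      + th.log(old_var / var).sum())
    kl = 0.5 * mean_part + cov_part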
class KLProjectionGradFunctionJoint(th.autograd.Function):
+    projection_op = None
+
+    @staticmethod
+    def get_projection_op(batch_shape, dim: int, max_eval: int = MAX_EVAL):
+        if not KLProjectionGradFunctionJoint.projection_op:
+            KLProjectionGradFunctionJoint.projection_op = \
+                cpp_projection.BatchedProjection(batch_shape, dim, eec=False, constrain_entropy=False,
+                                                 max_eval=max_eval)
+        return KLProjectionGradFunctionJoint.projection_op
+
+    @staticmethod
+    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
+        mean, cov, old_mean, old_cov, eps, beta = args
+
+        batch_shape, dim = mean.shape
+
+        mean_np = mean.detach().numpy()
+        cov_np = cov.detach().numpy()
+        old_mean = old_mean.detach().numpy()
+        old_cov = old_cov.detach().numpy()
+        eps = eps * np.ones(batch_shape)
+        beta = beta.detach().numpy() * np.ones(batch_shape)
+
+        # projection_op = cpp_projection.BatchedProjection(batch_shape, dim, eec=False, constrain_entropy=False)
+        # ctx.proj = projection_op
+
+        p_op = KLProjectionGradFunctionJoint.get_projection_op(
+            batch_shape, dim)
+        ctx.proj = p_op
+
+        proj_mean, proj_cov = p_op.forward(
+            eps, beta, old_mean, old_cov, mean_np, cov_np)
+
+        return mean.new(proj_mean), cov.new(proj_cov)
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Any) -> Any:
+        projection_op = ctx.proj
+        d_means, d_covs = grad_outputs
+        df_means, df_covs = projection_op.backward(
+            d_means.detach().numpy(), d_covs.detach().numpy())
+        return d_means.new(df_means), d_covs.new(df_covs), None, None, None, None
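Taken together, a projection layer is meant to wrap the policy update: project the freshly predicted distribution against the frozen old one, use the projected distribution downstream, and regularize with the trust-region loss. A sketch of the intended call site (policy_loss, p, and q are hypothetical placeholders, not part of this patch):

    from fancy_rl.projections import KLProjectionLayer

    proj_layer = KLProjectionLayer(mean_bound=0.03, cov_bound=1e-3,
                                   trust_region_coeff=8.0)

    # Inside the update step (sketch):
    #   p -- current policy distribution for a batch of observations
    #   q -- frozen behavior/old distribution for the same batch
    # proj_p = proj_layer(p, q)                      # differentiable projection
    # loss = policy_loss(proj_p) \
    #      + proj_layer.get_trust_region_loss(p, proj_p)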
diff --git a/fancy_rl/projections/w2_projection_layer.py b/fancy_rl/projections/w2_projection_layer.py
new file mode 100644
index 0000000..7296f28
--- /dev/null
+++ b/fancy_rl/projections/w2_projection_layer.py
@@ -0,0 +1,164 @@
+import numpy as np
+import torch as th
+from typing import Tuple, Any
+
+from .base_projection_layer import BaseProjectionLayer, mean_projection
+
+from ..misc.norm import mahalanobis, _batch_trace
+from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, has_diag_cov
+
+from stable_baselines3.common.distributions import Distribution
+
+
+class WassersteinProjectionLayer(BaseProjectionLayer):
+    """
+    Stolen from Fabian's Code (Public Version)
+    """
+
+    def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):
+        """
+        Runs the commutative Wasserstein projection layer and constructs the sqrt of the covariance.
+        Args:
+            p: current distribution
+            q: old distribution
+            eps: (modified) W2 bound for the mean part
+            eps_cov: (modified) W2 bound for the cov part
+            **kwargs:
+
+        Returns:
+            projected distribution
+        """
+        mean, sqrt = get_mean_and_sqrt(p, expand=True)
+        old_mean, old_sqrt = get_mean_and_sqrt(q, expand=True)
+        batch_shape = mean.shape[:-1]
+
+        ################################################################################################################
+        # precompute mean and cov part of W2, which are used for the projection.
+        # Both parts differ based on precision scaling.
+        # If activated, the mean part is the maha distance and the cov has a more complex term in the inner parenthesis.
+        mean_part, cov_part = gaussian_wasserstein_commutative(
+            p, q, self.scale_prec)
+
+        ################################################################################################################
+        # project mean (w/ or w/o precision scaling)
+        proj_mean = mean_projection(mean, old_mean, mean_part, eps)
+
+        ################################################################################################################
+        # project covariance (w/ or w/o precision scaling)
+
+        cov_mask = cov_part > eps_cov
+
+        if cov_mask.any():
+            # gradient issue with th.where, it executes both paths and gives NaN gradient.
+            eta = th.ones(batch_shape, dtype=sqrt.dtype, device=sqrt.device)
+            eta[cov_mask] = th.sqrt(cov_part[cov_mask] / eps_cov) - 1.
+            eta = th.max(-eta, eta)
+
+            new_sqrt = (sqrt + th.einsum('i,ijk->ijk', eta, old_sqrt)
+                        ) / (1. + eta + 1e-16)[..., None, None]
+            proj_sqrt = th.where(cov_mask[..., None, None], new_sqrt, sqrt)
+        else:
+            proj_sqrt = sqrt
+
+        if has_diag_cov(p):
+            proj_sqrt = th.diagonal(proj_sqrt, dim1=-2, dim2=-1)
+
+        proj_p = self.new_dist_like(p, proj_mean, proj_sqrt)
+        return proj_p
+
+    def trust_region_value(self, p, q):
+        """
+        Computes the Wasserstein distance between two Gaussian distributions p and q.
+        Args:
+            p: current distribution
+            q: old distribution
+        Returns:
+            mean and covariance part of the Wasserstein distance
+        """
+        mean_part, cov_part = gaussian_wasserstein_commutative(
+            p, q, scale_prec=self.scale_prec)
+        return mean_part + cov_part
+
+    def get_trust_region_loss(self, p, proj_p):
+        # p: predicted distribution from network output
+        # proj_p: projected distribution
+        proj_mean, proj_sqrt = get_mean_and_sqrt(proj_p)
+        p_target = self.new_dist_like(p, proj_mean, proj_sqrt)
+        kl_diff = self.trust_region_value(p, p_target)
+
+        kl_loss = kl_diff.mean()
+
+        return kl_loss * self.trust_region_coeff
+
+    def new_dist_like(self, orig_p, mean, cov_sqrt):
+        assert isinstance(orig_p, Distribution)
+        p = orig_p.distribution
+        if isinstance(p, th.distributions.Normal):
+            p_out = orig_p.__class__(orig_p.action_dim)
+            p_out.distribution = th.distributions.Normal(mean, cov_sqrt)
+        elif isinstance(p, th.distributions.Independent):
+            p_out = orig_p.__class__(orig_p.action_dim)
+            p_out.distribution = th.distributions.Independent(
+                th.distributions.Normal(mean, cov_sqrt), 1)
+        elif isinstance(p, th.distributions.MultivariateNormal):
+            p_out = orig_p.__class__(orig_p.action_dim)
+            p_out.distribution = th.distributions.MultivariateNormal(
+                mean, scale_tril=cov_sqrt, validate_args=False)
+        else:
+            raise Exception('Dist-Type not implemented (of sb3 dist)')
+        p_out.cov_sqrt = cov_sqrt
+        return p_out
+
+
+def gaussian_wasserstein_commutative(p, q, scale_prec=False) -> Tuple[th.Tensor, th.Tensor]:
+    """
+    Compute the mean part and cov part of W_2(p || q) with p, q ~ N(mu, S S^T).
+    This version DOES assume commutativity of the two covariance matrices.
+    This is less general and assumes both distributions are somewhat close together.
+    When scale_prec is true, scale both distributions with the old precision matrix.
+    Args:
+        p: mean and sqrt of gaussian p
+        q: mean and sqrt of gaussian q
+        scale_prec: scale objective by old precision matrix.
+                    This penalizes directions based on old uncertainty/covariance.
+    Returns: mean part of W2, cov part of W2
+    """
+    mean, sqrt = get_mean_and_sqrt(p, expand=True)
+    mean_other, sqrt_other = get_mean_and_sqrt(q, expand=True)
+
+    if scale_prec:
+        # maha objective for mean
+        mean_part = mahalanobis(mean, mean_other, sqrt_other)
+    else:
+        # euclidean distance for mean
+        # mean_part = th.norm(mean_other - mean, ord=2, axis=1) ** 2
+        mean_part = ((mean_other - mean) ** 2).sum(1)
+
+    cov = get_cov(p)
+    if scale_prec and False:  # disabled, never evaluated
+        # cov constraint scaled with precision of old dist
+        batch_dim, dim = mean.shape
+
+        identity = th.eye(dim, dtype=sqrt.dtype, device=sqrt.device)
+        sqrt_inv_other = th.linalg.solve(sqrt_other, identity)
+        c = sqrt_inv_other @ cov @ sqrt_inv_other
+
+        cov_part = _batch_trace(
+            identity + c - 2 * sqrt_inv_other @ sqrt)
+    else:
+        # W2 objective for cov assuming normal W2 objective for mean
+        cov_other = get_cov(q)
+        cov_part = _batch_trace(
+            cov_other + cov - 2 * th.bmm(sqrt_other, sqrt))
+
+    return mean_part, cov_part
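When the two covariances commute, the covariance term above, tr(cov_other + cov - 2 * sqrt_other @ sqrt), equals the squared Frobenius distance between the matrix square roots, which is what makes this "commutative" W2 cheap to evaluate. A toy check with diagonal (hence commuting) covariances:

    import torch as th

    # Toy check of the commutative W2 decomposition for diagonal covariances.
    mean, sqrt = th.tensor([1.0, 0.0]), th.diag(th.tensor([2.0, 1.0]))
    mean_o, sqrt_o = th.zeros(2), th.eye(2)

    mean_part = ((mean_o - mean) ** 2).sum()
    cov, cov_o = sqrt @ sqrt, sqrt_o @ sqrt_o
    cov_part = th.trace(cov_o + cov - 2 * sqrt_o @ sqrt)

    # For commuting matrices this equals the squared Frobenius
    # distance of the matrix square roots:
    assert th.isclose(cov_part, ((sqrt - sqrt_o) ** 2).sum())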