diff --git a/fancy_rl/projections_legacy/__init__.py b/fancy_rl/projections_legacy/__init__.py deleted file mode 100644 index dc28ef1..0000000 --- a/fancy_rl/projections_legacy/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -try: - import cpp_projection -except ModuleNotFoundError: - from .base_projection_layer import ITPALExceptionLayer as KLProjectionLayer -else: - from .kl_projection_layer import KLProjectionLayer \ No newline at end of file diff --git a/fancy_rl/projections_legacy/base_projection_layer.py b/fancy_rl/projections_legacy/base_projection_layer.py deleted file mode 100644 index 9aae5ff..0000000 --- a/fancy_rl/projections_legacy/base_projection_layer.py +++ /dev/null @@ -1,191 +0,0 @@ -from typing import Any, Dict, Optional, Type, Union, Tuple, final - -import torch as th - -from fancy_rl.norm import * - -class BaseProjectionLayer(object): - def __init__(self, - mean_bound: float = 0.03, - cov_bound: float = 1e-3, - trust_region_coeff: float = 1.0, - scale_prec: bool = False, - ): - self.mean_bound = mean_bound - self.cov_bound = cov_bound - self.trust_region_coeff = trust_region_coeff - self.scale_prec = scale_prec - self.mean_eq = False - - def __call__(self, p, q, **kwargs): - return self._projection(p, q, eps=self.mean_bound, eps_cov=self.cov_bound, beta=None, **kwargs) - - @final - def _projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, beta: th.Tensor, **kwargs): - return self._trust_region_projection( - p, q, eps, eps_cov, **kwargs) - - def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs): - """ - Hook for implementing the specific trust region projection - Args: - p: current distribution - q: old distribution - eps: mean trust region bound - eps_cov: covariance trust region bound - **kwargs: - - Returns: - projected - """ - return p - - def get_trust_region_loss(self, p, proj_p): - # p: - # predicted distribution from network output - # proj_p: - # projected distribution - - proj_mean, proj_chol = get_mean_and_chol(proj_p) - p_target = new_dist_like(p, proj_mean, proj_chol) - kl_diff = self.trust_region_value(p, p_target) - - kl_loss = kl_diff.mean() - - return kl_loss * self.trust_region_coeff - - def trust_region_value(self, p, q): - """ - Computes the KL divergence between two Gaussian distributions p and q_values. - Returns: - full kl divergence - """ - return kl_divergence(p, q) - - def new_dist_like(self, orig_p, mean, cov_cholesky): - assert isinstance(orig_p, Distribution) - p = orig_p.distribution - if isinstance(p, th.distributions.Normal): - p_out = orig_p.__class__(orig_p.action_dim) - p_out.distribution = th.distributions.Normal(mean, cov_cholesky) - elif isinstance(p, th.distributions.Independent): - p_out = orig_p.__class__(orig_p.action_dim) - p_out.distribution = th.distributions.Independent( - th.distributions.Normal(mean, cov_cholesky), 1) - elif isinstance(p, th.distributions.MultivariateNormal): - p_out = orig_p.__class__(orig_p.action_dim) - p_out.distribution = th.distributions.MultivariateNormal( - mean, scale_tril=cov_cholesky) - else: - raise Exception('Dist-Type not implemented (of sb3 dist)') - return p_out - -def entropy_inequality_projection(p: th.distributions.Normal, - beta: Union[float, th.Tensor]): - """ - Projects std to satisfy an entropy INEQUALITY constraint. - Args: - p: current distribution - beta: target entropy for EACH std or general bound for all stds - - Returns: - projected std that satisfies the entropy bound - """ - mean, std = p.mean, p.stddev - k = std.shape[-1] - batch_shape = std.shape[:-2] - - ent = p.entropy() - mask = ent < beta - - # if nothing has to be projected skip computation - if (~mask).all(): - return p - - alpha = th.ones(batch_shape, dtype=std.dtype, device=std.device) - alpha[mask] = th.exp((beta[mask] - ent[mask]) / k) - - proj_std = th.einsum('ijk,i->ijk', std, alpha) - new_mean, new_std = mean, th.where(mask[..., None, None], proj_std, std) - return th.distributions.Normal(new_mean, new_std) - - -def entropy_equality_projection(p: th.distributions.Normal, - beta: Union[float, th.Tensor]): - """ - Projects std to satisfy an entropy EQUALITY constraint. - Args: - p: current distribution - beta: target entropy for EACH std or general bound for all stds - - Returns: - projected std that satisfies the entropy bound - """ - mean, std = p.mean, p.stddev - k = std.shape[-1] - - ent = p.entropy() - alpha = th.exp((beta - ent) / k) - proj_std = th.einsum('ijk,i->ijk', std, alpha) - new_mean, new_std = mean, proj_std - return th.distributions.Normal(new_mean, new_std) - - -def mean_projection(mean: th.Tensor, old_mean: th.Tensor, maha: th.Tensor, eps: th.Tensor): - """ - Projects the mean based on the Mahalanobis objective and trust region. - Args: - mean: current mean vectors - old_mean: old mean vectors - maha: Mahalanobis distance between the two mean vectors - eps: trust region bound - - Returns: - projected mean that satisfies the trust region - """ - batch_shape = mean.shape[:-1] - mask = maha > eps - - ################################################################################################################ - # mean projection maha - - # if nothing has to be projected skip computation - if mask.any(): - omega = th.ones(batch_shape, dtype=mean.dtype, device=mean.device) - omega[mask] = th.sqrt(maha[mask] / eps) - 1. - omega = th.max(-omega, omega)[..., None] - - m = (mean + omega * old_mean) / (1 + omega + 1e-16) - proj_mean = th.where(mask[..., None], m, mean) - else: - proj_mean = mean - - return proj_mean - - -def mean_equality_projection(mean: th.Tensor, old_mean: th.Tensor, maha: th.Tensor, eps: th.Tensor): - """ - Projections the mean based on the Mahalanobis objective and trust region for an EQUALITY constraint. - Args: - mean: current mean vectors - old_mean: old mean vectors - maha: Mahalanobis distance between the two mean vectors - eps: trust region bound - Returns: - projected mean that satisfies the trust region - """ - - maha[maha == 0] += 1e-16 - omega = th.sqrt(maha / eps) - 1. - omega = omega[..., None] - - proj_mean = (mean + omega * old_mean) / (1 + omega + 1e-16) - - return proj_mean - - -class ITPALExceptionLayer(BaseProjectionLayer): - def __init__(self, - *args, **kwargs - ): - raise Exception('To be able to use KL projections, ITPAL must be installed: https://github.com/ALRhub/ITPAL.') \ No newline at end of file diff --git a/fancy_rl/projections_legacy/frob_projection_layer.py b/fancy_rl/projections_legacy/frob_projection_layer.py deleted file mode 100644 index 2309ab2..0000000 --- a/fancy_rl/projections_legacy/frob_projection_layer.py +++ /dev/null @@ -1,133 +0,0 @@ -import torch as th -from typing import Tuple - -from .base_projection_layer import BaseProjectionLayer, mean_projection - -from ..misc.norm import mahalanobis, frob_sq -from ..misc.distTools import get_mean_and_chol, get_cov, new_dist_like, has_diag_cov - - -class FrobeniusProjectionLayer(BaseProjectionLayer): - - def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs): - """ - Stolen from Fabian's Code (Public Version) - - Runs Frobenius projection layer and constructs cholesky of covariance - - Args: - policy: policy instance - p: current distribution - q: old distribution - eps: (modified) kl bound/ kl bound for mean part - eps_cov: (modified) kl bound for cov part - beta: (modified) entropy bound - **kwargs: - Returns: mean, cov cholesky - """ - - mean, chol = get_mean_and_chol(p, expand=True) - old_mean, old_chol = get_mean_and_chol(q, expand=True) - batch_shape = mean.shape[:-1] - - #################################################################################################################### - # precompute mean and cov part of frob projection, which are used for the projection. - mean_part, cov_part, cov, cov_old = gaussian_frobenius( - p, q, self.scale_prec, True) - - ################################################################################################################ - # mean projection maha/euclidean - - proj_mean = mean_projection(mean, old_mean, mean_part, eps) - - ################################################################################################################ - # cov projection frobenius - - cov_mask = cov_part > eps_cov - - if cov_mask.any(): - eta = th.ones(batch_shape, dtype=chol.dtype, device=chol.device) - eta[cov_mask] = th.sqrt(cov_part[cov_mask] / eps_cov) - 1. - eta = th.max(-eta, eta) - - new_cov = (cov + th.einsum('i,ijk->ijk', eta, cov_old) - ) / (1. + eta + 1e-16)[..., None, None] - proj_chol = th.where( - cov_mask[..., None, None], th.linalg.cholesky(new_cov), chol) - else: - proj_chol = chol - - if has_diag_cov(p): - proj_chol = th.diagonal(proj_chol, dim1=-2, dim2=-1) - - proj_p = new_dist_like(p, proj_mean, proj_chol) - return proj_p - - def trust_region_value(self, p, q): - """ - Stolen from Fabian's Code (Public Version) - - Computes the Frobenius metric between two Gaussian distributions p and q. - Args: - policy: policy instance - p: current distribution - q: old distribution - Returns: - mean and covariance part of Frobenius metric - """ - return gaussian_frobenius(p, q, self.scale_prec) - - def get_trust_region_loss(self, p, proj_p): - """ - Stolen from Fabian's Code (Public Version) - """ - - mean_diff, _ = self.trust_region_value(p, proj_p) - if False and policy.contextual_std: - # Compute MSE here, because we found the Frobenius norm tends to generate values that explode for the cov - p_mean, proj_p_mean = p.mean, proj_p.mean - cov_diff = (p_mean - proj_p_mean).pow(2).sum([-1, -2]) - delta_loss = (mean_diff + cov_diff).mean() - else: - delta_loss = mean_diff.mean() - - return delta_loss * self.trust_region_coeff - - -def gaussian_frobenius(p, q, scale_prec: bool = False, return_cov: bool = False): - """ - Stolen from Fabian' Code (Public Version) - - Compute (p - q_values) (L_oL_o^T)^-1 (p - 1)^T + |LL^T - L_oL_o^T|_F^2 with p,q_values ~ N(y, LL^T) - Args: - policy: current policy - p: mean and chol of gaussian p - q: mean and chol of gaussian q_values - return_cov: return cov matrices for further computations - scale_prec: scale objective with precision matrix - Returns: mahalanobis distance, squared frobenius norm - """ - - mean, chol = get_mean_and_chol(p) - mean_other, chol_other = get_mean_and_chol(q) - - if scale_prec: - # maha objective for mean - mean_part = mahalanobis(mean, mean_other, chol_other) - else: - # euclidean distance for mean - # mean_part = ch.norm(mean_other - mean, ord=2, axis=1) ** 2 - mean_part = ((mean_other - mean) ** 2).sum(1) - - # frob objective for cov - cov = get_cov(p) - cov_other = get_cov(q) - diff = cov_other - cov - # Matrix is real symmetric PSD, therefore |A @ A^H|^2_F = tr{A @ A^H} = tr{A @ A} - #cov_part = torch_batched_trace(diff @ diff) - cov_part = frob_sq(diff, is_spd=True) - - if return_cov: - return mean_part, cov_part, cov, cov_other - - return mean_part, cov_part diff --git a/fancy_rl/projections_legacy/identity_projection_layer.py b/fancy_rl/projections_legacy/identity_projection_layer.py deleted file mode 100644 index 62216db..0000000 --- a/fancy_rl/projections_legacy/identity_projection_layer.py +++ /dev/null @@ -1,5 +0,0 @@ -from .base_projection_layer import BaseProjectionLayer - -class IdentityProjectionLayer(BaseProjectionLayer): - def project_from_rollouts(self, dist, rollout_data, **kwargs): - return dist, dist diff --git a/fancy_rl/projections_legacy/kl_projection_layer.py b/fancy_rl/projections_legacy/kl_projection_layer.py deleted file mode 100644 index cc38126..0000000 --- a/fancy_rl/projections_legacy/kl_projection_layer.py +++ /dev/null @@ -1,256 +0,0 @@ -from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_cov, is_contextual, new_dist_like, has_diag_cov -from .base_projection_layer import BaseProjectionLayer, mean_projection, mean_equality_projection - -import cpp_projection -import numpy as np -import torch as th -from typing import Tuple, Any - -from ..misc.norm import mahalanobis - -MAX_EVAL = 1000 - - -class KLProjectionLayer(BaseProjectionLayer): - """ - Stolen from Fabian's Code (Private Version) - """ - - def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs): - """ - Stolen from Fabian's Code (Private Version) - - runs kl projection layer and constructs sqrt of covariance - Args: - **kwargs: - policy: policy instance - p: current distribution - q: old distribution - eps: (modified) kl bound/ kl bound for mean part - eps_cov: (modified) kl bound for cov part - - Returns: - mean, cov sqrt - """ - mean, chol = get_mean_and_chol(p, expand=True) - old_mean, old_chol = get_mean_and_chol(q, expand=True) - - ################################################################################################################ - # project mean with closed form - # orig code: mean_part, _ = gaussian_kl(policy, p, q) - # But the mean_part is just the mahalanobis dist: - mean_part = mahalanobis(mean, old_mean, old_chol) - if self.mean_eq: - proj_mean = mean_equality_projection( - mean, old_mean, mean_part, eps) - else: - proj_mean = mean_projection(mean, old_mean, mean_part, eps) - - if has_diag_cov(p): - cov_diag = get_diag_cov_vec(p) - old_cov_diag = get_diag_cov_vec(q) - proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov_diag, - old_cov_diag, - eps_cov) - proj_chol = proj_cov.sqrt() # .diag_embed() - else: - cov = get_cov(p) - old_cov = get_cov(q) - proj_cov = KLProjectionGradFunctionCovOnly.apply( - cov, old_cov, chol, old_chol, eps_cov) - proj_chol = th.linalg.cholesky(proj_cov) - proj_p = new_dist_like(p, proj_mean, proj_chol) - return proj_p - - -class KLProjectionGradFunctionCovOnly(th.autograd.Function): - projection_op = None - - @staticmethod - def get_projection_op(batch_shape, dim, max_eval=MAX_EVAL): - if not KLProjectionGradFunctionCovOnly.projection_op: - KLProjectionGradFunctionCovOnly.projection_op = \ - cpp_projection.BatchedCovOnlyProjection( - batch_shape, dim, max_eval=max_eval) - return KLProjectionGradFunctionCovOnly.projection_op - - @staticmethod - def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: - #std, old_std, eps_cov = args - cov, old_cov, chol, old_chol, eps_cov = args - - batch_shape = chol.shape[0] - dim = chol.shape[-1] - - cov_np = cov.cpu().detach().numpy() - old_cov_np = old_cov.cpu().detach().numpy() - chol_np = chol.cpu().detach().numpy() - old_chol_np = old_chol.cpu().detach().numpy() - # eps = eps_cov.cpu().detach().numpy().astype(old_std_np.dtype) * \ - eps = eps_cov * \ - np.ones(batch_shape, dtype=old_chol_np.dtype) - - p_op = KLProjectionGradFunctionCovOnly.get_projection_op( - batch_shape, dim) - ctx.proj = p_op - - proj_cov = p_op.forward(eps, old_chol_np, chol_np, cov_np) - - return th.Tensor(proj_cov) - - @staticmethod - def backward(ctx: Any, *grad_outputs: Any) -> Any: - projection_op = ctx.proj - d_std, = grad_outputs - - d_std_np = d_std.cpu().detach().numpy() - d_std_np = np.atleast_2d(d_std_np) - df_stds = projection_op.backward(d_std_np) - df_stds = np.atleast_2d(df_stds) - - return d_std.new(df_stds), None, None, None, None - - -class KLProjectionGradFunctionDiagCovOnly(th.autograd.Function): - projection_op = None - - @staticmethod - def get_projection_op(batch_shape, dim: int, max_eval: int = MAX_EVAL): - if not KLProjectionGradFunctionDiagCovOnly.projection_op: - KLProjectionGradFunctionDiagCovOnly.projection_op = \ - cpp_projection.BatchedDiagCovOnlyProjection( - batch_shape, dim, max_eval=max_eval) - return KLProjectionGradFunctionDiagCovOnly.projection_op - - @staticmethod - def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: - cov, old_std_np, eps_cov = args - - batch_shape = cov.shape[0] - dim = cov.shape[-1] - - std_np = cov.to('cpu').detach().numpy() - old_std_np = old_std_np.to('cpu').detach().numpy() - # eps = eps_cov.to('cpu').detach().numpy().astype(old_std_np.dtype) * np.ones(batch_shape, dtype=old_std_np.dtype) - eps = eps_cov * np.ones(batch_shape, dtype=old_std_np.dtype) - - p_op = KLProjectionGradFunctionDiagCovOnly.get_projection_op( - batch_shape, dim) - ctx.proj = p_op - - try: - proj_std = p_op.forward(eps, old_std_np, std_np) - except: - proj_std = std_np - - return cov.new(proj_std) - - @staticmethod - def backward(ctx: Any, *grad_outputs: Any) -> Any: - projection_op = ctx.proj - d_std, = grad_outputs - - d_std_np = d_std.to('cpu').detach().numpy() - d_std_np = np.atleast_2d(d_std_np) - df_stds = projection_op.backward(d_std_np) - df_stds = np.atleast_2d(df_stds) - - return d_std.new(df_stds), None, None - - -class KLProjectionGradFunctionDiagSplit(th.autograd.Function): - projection_op = None - - @staticmethod - def get_projection_op(batch_shape, dim: int, max_eval: int = MAX_EVAL): - if not KLProjectionGradFunctionDiagSplit.projection_op: - KLProjectionGradFunctionDiagSplit.projection_op = \ - cpp_projection.BatchedSplitDiagMoreProjection( - batch_shape, dim, max_eval=max_eval) - return KLProjectionGradFunctionDiagSplit.projection_op - - @staticmethod - def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: - mean, cov, old_mean, old_cov, eps_mu, eps_sigma = args - - batch_shape, dim = mean.shape - - mean_np = mean.detach().numpy() - cov_np = cov.detach().numpy() - old_mean = old_mean.detach().numpy() - old_cov = old_cov.detach().numpy() - eps_mu = eps_mu * np.ones(batch_shape) - eps_sigma = eps_sigma * np.ones(batch_shape) - - # p_op = cpp_projection.BatchedSplitDiagMoreProjection(batch_shape, dim, max_eval=100) - p_op = KLProjectionGradFunctionDiagSplit.get_projection_op( - batch_shape, dim) - - try: - proj_mean, proj_cov = p_op.forward( - eps_mu, eps_sigma, old_mean, old_cov, mean_np, cov_np) - except Exception: - # try a second time - proj_mean, proj_cov = p_op.forward( - eps_mu, eps_sigma, old_mean, old_cov, mean_np, cov_np) - ctx.proj = p_op - - return mean.new(proj_mean), cov.new(proj_cov) - - @staticmethod - def backward(ctx: Any, *grad_outputs: Any) -> Any: - p_op = ctx.proj - d_means, d_std = grad_outputs - - d_std_np = d_std.detach().numpy() - d_std_np = np.atleast_2d(d_std_np) - d_mean_np = d_means.detach().numpy() - dtarget_means, dtarget_covs = p_op.backward(d_mean_np, d_std_np) - dtarget_covs = np.atleast_2d(dtarget_covs) - - return d_means.new(dtarget_means), d_std.new(dtarget_covs), None, None, None, None - - -class KLProjectionGradFunctionJoint(th.autograd.Function): - projection_op = None - - @staticmethod - def get_projection_op(batch_shape, dim: int, max_eval: int = MAX_EVAL): - if not KLProjectionGradFunctionJoint.projection_op: - KLProjectionGradFunctionJoint.projection_op = \ - cpp_projection.BatchedProjection(batch_shape, dim, eec=False, constrain_entropy=False, - max_eval=max_eval) - return KLProjectionGradFunctionJoint.projection_op - - @staticmethod - def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: - mean, cov, old_mean, old_cov, eps, beta = args - - batch_shape, dim = mean.shape - - mean_np = mean.detach().numpy() - cov_np = cov.detach().numpy() - old_mean = old_mean.detach().numpy() - old_cov = old_cov.detach().numpy() - eps = eps * np.ones(batch_shape) - beta = beta.detach().numpy() * np.ones(batch_shape) - - # projection_op = cpp_projection.BatchedProjection(batch_shape, dim, eec=False, constrain_entropy=False) - # ctx.proj = projection_op - - p_op = KLProjectionGradFunctionJoint.get_projection_op( - batch_shape, dim) - ctx.proj = p_op - - proj_mean, proj_cov = p_op.forward( - eps, beta, old_mean, old_cov, mean_np, cov_np) - - return mean.new(proj_mean), cov.new(proj_cov) - - @staticmethod - def backward(ctx: Any, *grad_outputs: Any) -> Any: - projection_op = ctx.proj - d_means, d_covs = grad_outputs - df_means, df_covs = projection_op.backward( - d_means.detach().numpy(), d_covs.detach().numpy()) - return d_means.new(df_means), d_means.new(df_covs), None, None, None, None diff --git a/fancy_rl/projections_legacy/w2_projection_layer.py b/fancy_rl/projections_legacy/w2_projection_layer.py deleted file mode 100644 index 7296f28..0000000 --- a/fancy_rl/projections_legacy/w2_projection_layer.py +++ /dev/null @@ -1,164 +0,0 @@ -import numpy as np -import torch as th -from typing import Tuple, Any - -from ..misc.norm import mahalanobis - -from .base_projection_layer import BaseProjectionLayer, mean_projection - -from ..misc.norm import mahalanobis, _batch_trace -from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, has_diag_cov - -from stable_baselines3.common.distributions import Distribution - - -class WassersteinProjectionLayer(BaseProjectionLayer): - """ - Stolen from Fabian's Code (Public Version) - """ - - def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs): - """ - Runs commutative Wasserstein projection layer and constructs sqrt of covariance - Args: - policy: policy instance - p: current distribution - q: old distribution - eps: (modified) kl bound/ kl bound for mean part - eps_cov: (modified) kl bound for cov part - **kwargs: - - Returns: - mean, cov sqrt - """ - - mean, sqrt = get_mean_and_sqrt(p, expand=True) - old_mean, old_sqrt = get_mean_and_sqrt(q, expand=True) - batch_shape = mean.shape[:-1] - - #################################################################################################################### - # precompute mean and cov part of W2, which are used for the projection. - # Both parts differ based on precision scaling. - # If activated, the mean part is the maha distance and the cov has a more complex term in the inner parenthesis. - mean_part, cov_part = gaussian_wasserstein_commutative( - p, q, self.scale_prec) - - #################################################################################################################### - # project mean (w/ or w/o precision scaling) - proj_mean = mean_projection(mean, old_mean, mean_part, eps) - - #################################################################################################################### - # project covariance (w/ or w/o precision scaling) - - cov_mask = cov_part > eps_cov - - if cov_mask.any(): - # gradient issue with ch.where, it executes both paths and gives NaN gradient. - eta = th.ones(batch_shape, dtype=sqrt.dtype, device=sqrt.device) - eta[cov_mask] = th.sqrt(cov_part[cov_mask] / eps_cov) - 1. - eta = th.max(-eta, eta) - - new_sqrt = (sqrt + th.einsum('i,ijk->ijk', eta, old_sqrt) - ) / (1. + eta + 1e-16)[..., None, None] - proj_sqrt = th.where(cov_mask[..., None, None], new_sqrt, sqrt) - else: - proj_sqrt = sqrt - - if has_diag_cov(p): - proj_sqrt = th.diagonal(proj_sqrt, dim1=-2, dim2=-1) - - proj_p = self.new_dist_like(p, proj_mean, proj_sqrt) - return proj_p - - def trust_region_value(self, p, q): - """ - Computes the Wasserstein distance between two Gaussian distributions p and q. - Args: - policy: policy instance - p: current distribution - q: old distribution - Returns: - mean and covariance part of Wasserstein distance - """ - mean_part, cov_part = gaussian_wasserstein_commutative( - p, q, scale_prec=self.scale_prec) - return mean_part + cov_part - - def get_trust_region_loss(self, p, proj_p): - # p: - # predicted distribution from network output - # proj_p: - # projected distribution - - proj_mean, proj_sqrt = get_mean_and_sqrt(proj_p) - p_target = self.new_dist_like(p, proj_mean, proj_sqrt) - kl_diff = self.trust_region_value(p, p_target) - - kl_loss = kl_diff.mean() - - return kl_loss * self.trust_region_coeff - - def new_dist_like(self, orig_p, mean, cov_sqrt): - assert isinstance(orig_p, Distribution) - p = orig_p.distribution - if isinstance(p, th.distributions.Normal): - p_out = orig_p.__class__(orig_p.action_dim) - p_out.distribution = th.distributions.Normal(mean, cov_sqrt) - elif isinstance(p, th.distributions.Independent): - p_out = orig_p.__class__(orig_p.action_dim) - p_out.distribution = th.distributions.Independent( - th.distributions.Normal(mean, cov_sqrt), 1) - elif isinstance(p, th.distributions.MultivariateNormal): - p_out = orig_p.__class__(orig_p.action_dim) - p_out.distribution = th.distributions.MultivariateNormal( - mean, scale_tril=cov_sqrt, validate_args=False) - else: - raise Exception('Dist-Type not implemented (of sb3 dist)') - p_out.cov_sqrt = cov_sqrt - return p_out - - -def gaussian_wasserstein_commutative(p, q, scale_prec=False) -> Tuple[th.Tensor, th.Tensor]: - """ - Compute mean part and cov part of W_2(p || q_values) with p,q_values ~ N(y, SS). - This version DOES assume commutativity of both distributions, i.e. covariance matrices. - This is less general and assumes both distributions are somewhat close together. - When scale_prec is true scale both distributions with old precision matrix. - Args: - policy: current policy - p: mean and sqrt of gaussian p - q: mean and sqrt of gaussian q_values - scale_prec: scale objective by old precision matrix. - This penalizes directions based on old uncertainty/covariance. - Returns: mean part of W2, cov part of W2 - """ - mean, sqrt = get_mean_and_sqrt(p, expand=True) - mean_other, sqrt_other = get_mean_and_sqrt(q, expand=True) - - if scale_prec: - # maha objective for mean - mean_part = mahalanobis(mean, mean_other, sqrt_other) - else: - # euclidean distance for mean - # mean_part = ch.norm(mean_other - mean, ord=2, axis=1) ** 2 - mean_part = ((mean_other - mean) ** 2).sum(1) - - cov = get_cov(p) - if scale_prec and False: - # cov constraint scaled with precision of old dist - batch_dim, dim = mean.shape - - identity = th.eye(dim, dtype=sqrt.dtype, device=sqrt.device) - sqrt_inv_other = th.linalg.solve(sqrt_other, identity) - c = sqrt_inv_other @ cov @ sqrt_inv_other - - cov_part = _batch_trace( - identity + c - 2 * sqrt_inv_other @ sqrt) - - else: - # W2 objective for cov assuming normal W2 objective for mean - cov_other = get_cov(q) - cov_part = _batch_trace( - cov_other + cov - 2 * th.bmm(sqrt_other, sqrt)) - - return mean_part, cov_part