metastable-baselines/projections_orig/w2_projection_layer.py

84 lines
3.7 KiB
Python
Raw Normal View History

# Copyright (c) 2021 Robert Bosch GmbH
# Author: Fabian Otto
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import torch as ch
from typing import Tuple
from trust_region_projections.models.policy.abstract_gaussian_policy import AbstractGaussianPolicy
from trust_region_projections.projections.base_projection_layer import BaseProjectionLayer, mean_projection
from trust_region_projections.utils.projection_utils import gaussian_wasserstein_commutative
class WassersteinProjectionLayer(BaseProjectionLayer):
def _trust_region_projection(self, policy: AbstractGaussianPolicy, p: Tuple[ch.Tensor, ch.Tensor],
q: Tuple[ch.Tensor, ch.Tensor], eps: ch.Tensor, eps_cov: ch.Tensor, **kwargs):
"""
Runs commutative Wasserstein projection layer and constructs sqrt of covariance
Args:
policy: policy instance
p: current distribution
q: old distribution
eps: (modified) kl bound/ kl bound for mean part
eps_cov: (modified) kl bound for cov part
**kwargs:
Returns:
mean, cov sqrt
"""
mean, sqrt = p
old_mean, old_sqrt = q
batch_shape = mean.shape[:-1]
####################################################################################################################
# precompute mean and cov part of W2, which are used for the projection.
# Both parts differ based on precision scaling.
# If activated, the mean part is the maha distance and the cov has a more complex term in the inner parenthesis.
mean_part, cov_part = gaussian_wasserstein_commutative(policy, p, q, self.scale_prec)
####################################################################################################################
# project mean (w/ or w/o precision scaling)
proj_mean = mean_projection(mean, old_mean, mean_part, eps)
####################################################################################################################
# project covariance (w/ or w/o precision scaling)
cov_mask = cov_part > eps_cov
if cov_mask.any():
# gradient issue with ch.where, it executes both paths and gives NaN gradient.
eta = ch.ones(batch_shape, dtype=sqrt.dtype, device=sqrt.device)
eta[cov_mask] = ch.sqrt(cov_part[cov_mask] / eps_cov) - 1.
eta = ch.max(-eta, eta)
new_sqrt = (sqrt + ch.einsum('i,ijk->ijk', eta, old_sqrt)) / (1. + eta + 1e-16)[..., None, None]
proj_sqrt = ch.where(cov_mask[..., None, None], new_sqrt, sqrt)
else:
proj_sqrt = sqrt
return proj_mean, proj_sqrt
def trust_region_value(self, policy, p, q):
"""
Computes the Wasserstein distance between two Gaussian distributions p and q.
Args:
policy: policy instance
p: current distribution
q: old distribution
Returns:
mean and covariance part of Wasserstein distance
"""
return gaussian_wasserstein_commutative(policy, p, q, scale_prec=self.scale_prec)