Initial code for projections
This commit is contained in:
parent: add8e92b4a
commit: 78d79cf705

fancy_rl/projections/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
try:
    import cpp_projection
except ModuleNotFoundError:
    from .base_projection_layer import ITPALExceptionLayer as KLProjectionLayer
else:
    from .kl_projection_layer import KLProjectionLayer
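The try/except/else above makes the KL projection optional: when the compiled cpp_projection extension (ITPAL) is absent, KLProjectionLayer resolves to ITPALExceptionLayer, which raises on construction with installation instructions. A minimal sketch of the downstream effect, assuming the package is importable as fancy_rl:

import fancy_rl.projections as projections

# With cpp_projection installed this is the real KL layer; without it,
# constructing the layer raises and points to https://github.com/ALRhub/ITPAL.
proj = projections.KLProjectionLayer(mean_bound=0.03, cov_bound=1e-3)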
fancy_rl/projections/base_projection_layer.py (new file, 191 lines)
@@ -0,0 +1,191 @@
from typing import Any, Dict, Optional, Type, Union, Tuple, final

import torch as th
# Assumed imports (not in the original hunk): kl_divergence, the sb3
# Distribution wrapper, and the distTools helpers are referenced below and
# may otherwise only be provided by the wildcard import.
from torch.distributions import kl_divergence
from stable_baselines3.common.distributions import Distribution

from fancy_rl.norm import *
from fancy_rl.misc.distTools import get_mean_and_chol, new_dist_like


class BaseProjectionLayer(object):
    def __init__(self,
                 mean_bound: float = 0.03,
                 cov_bound: float = 1e-3,
                 trust_region_coeff: float = 1.0,
                 scale_prec: bool = False,
                 ):
        self.mean_bound = mean_bound
        self.cov_bound = cov_bound
        self.trust_region_coeff = trust_region_coeff
        self.scale_prec = scale_prec
        self.mean_eq = False

    def __call__(self, p, q, **kwargs):
        return self._projection(p, q, eps=self.mean_bound, eps_cov=self.cov_bound, beta=None, **kwargs)

    @final
    def _projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, beta: th.Tensor, **kwargs):
        return self._trust_region_projection(
            p, q, eps, eps_cov, **kwargs)

    def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):
        """
        Hook for implementing the specific trust region projection.
        Args:
            p: current distribution
            q: old distribution
            eps: mean trust region bound
            eps_cov: covariance trust region bound
            **kwargs:

        Returns:
            projected distribution
        """
        return p

    def get_trust_region_loss(self, p, proj_p):
        # p: predicted distribution from network output
        # proj_p: projected distribution

        proj_mean, proj_chol = get_mean_and_chol(proj_p)
        p_target = new_dist_like(p, proj_mean, proj_chol)
        kl_diff = self.trust_region_value(p, p_target)

        kl_loss = kl_diff.mean()

        return kl_loss * self.trust_region_coeff

    def trust_region_value(self, p, q):
        """
        Computes the KL divergence between two Gaussian distributions p and q.
        Returns:
            full kl divergence
        """
        return kl_divergence(p, q)

    def new_dist_like(self, orig_p, mean, cov_cholesky):
        assert isinstance(orig_p, Distribution)
        p = orig_p.distribution
        if isinstance(p, th.distributions.Normal):
            p_out = orig_p.__class__(orig_p.action_dim)
            p_out.distribution = th.distributions.Normal(mean, cov_cholesky)
        elif isinstance(p, th.distributions.Independent):
            p_out = orig_p.__class__(orig_p.action_dim)
            p_out.distribution = th.distributions.Independent(
                th.distributions.Normal(mean, cov_cholesky), 1)
        elif isinstance(p, th.distributions.MultivariateNormal):
            p_out = orig_p.__class__(orig_p.action_dim)
            p_out.distribution = th.distributions.MultivariateNormal(
                mean, scale_tril=cov_cholesky)
        else:
            raise Exception('Dist-Type not implemented (of sb3 dist)')
        return p_out


def entropy_inequality_projection(p: th.distributions.Normal,
                                  beta: Union[float, th.Tensor]):
    """
    Projects std to satisfy an entropy INEQUALITY constraint.
    Args:
        p: current distribution
        beta: target entropy for EACH std or general bound for all stds

    Returns:
        projected std that satisfies the entropy bound
    """
    mean, std = p.mean, p.stddev
    k = std.shape[-1]
    batch_shape = std.shape[:-2]

    ent = p.entropy()
    mask = ent < beta

    # if nothing has to be projected skip computation
    if (~mask).all():
        return p

    alpha = th.ones(batch_shape, dtype=std.dtype, device=std.device)
    alpha[mask] = th.exp((beta[mask] - ent[mask]) / k)

    proj_std = th.einsum('ijk,i->ijk', std, alpha)
    new_mean, new_std = mean, th.where(mask[..., None, None], proj_std, std)
    return th.distributions.Normal(new_mean, new_std)


def entropy_equality_projection(p: th.distributions.Normal,
                                beta: Union[float, th.Tensor]):
    """
    Projects std to satisfy an entropy EQUALITY constraint.
    Args:
        p: current distribution
        beta: target entropy for EACH std or general bound for all stds

    Returns:
        projected std that satisfies the entropy bound
    """
    mean, std = p.mean, p.stddev
    k = std.shape[-1]

    ent = p.entropy()
    alpha = th.exp((beta - ent) / k)
    proj_std = th.einsum('ijk,i->ijk', std, alpha)
    new_mean, new_std = mean, proj_std
    return th.distributions.Normal(new_mean, new_std)


def mean_projection(mean: th.Tensor, old_mean: th.Tensor, maha: th.Tensor, eps: th.Tensor):
    """
    Projects the mean based on the Mahalanobis objective and trust region.
    Args:
        mean: current mean vectors
        old_mean: old mean vectors
        maha: Mahalanobis distance between the two mean vectors
        eps: trust region bound

    Returns:
        projected mean that satisfies the trust region
    """
    batch_shape = mean.shape[:-1]
    mask = maha > eps

    ################################################################################################################
    # mean projection maha

    # if nothing has to be projected skip computation
    if mask.any():
        omega = th.ones(batch_shape, dtype=mean.dtype, device=mean.device)
        omega[mask] = th.sqrt(maha[mask] / eps) - 1.
        omega = th.max(-omega, omega)[..., None]

        m = (mean + omega * old_mean) / (1 + omega + 1e-16)
        proj_mean = th.where(mask[..., None], m, mean)
    else:
        proj_mean = mean

    return proj_mean


def mean_equality_projection(mean: th.Tensor, old_mean: th.Tensor, maha: th.Tensor, eps: th.Tensor):
    """
    Projects the mean based on the Mahalanobis objective and trust region for an EQUALITY constraint.
    Args:
        mean: current mean vectors
        old_mean: old mean vectors
        maha: Mahalanobis distance between the two mean vectors
        eps: trust region bound
    Returns:
        projected mean that satisfies the trust region
    """
    maha[maha == 0] += 1e-16
    omega = th.sqrt(maha / eps) - 1.
    omega = omega[..., None]

    proj_mean = (mean + omega * old_mean) / (1 + omega + 1e-16)

    return proj_mean


class ITPALExceptionLayer(BaseProjectionLayer):
    def __init__(self,
                 *args, **kwargs
                 ):
        raise Exception('To be able to use KL projections, ITPAL must be installed: https://github.com/ALRhub/ITPAL.')
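For maha > eps, mean_projection interpolates back toward the old mean with weight omega = sqrt(maha / eps) - 1, which places the result exactly on the trust region boundary. A small self-contained check of that arithmetic, using a squared Euclidean distance as a stand-in for the Mahalanobis term:

import torch as th
from fancy_rl.projections.base_projection_layer import mean_projection

mean = th.tensor([[1.0, 2.0]])            # current mean, batch of 1
old_mean = th.tensor([[0.0, 0.0]])        # old mean
maha = ((mean - old_mean) ** 2).sum(-1)   # squared distance -> 5.0, above the bound
eps = th.tensor(0.5)                      # trust region bound

proj = mean_projection(mean, old_mean, maha, eps)
print(((proj - old_mean) ** 2).sum(-1))   # tensor([0.5000]): lands on the bound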
fancy_rl/projections/frob_projection_layer.py (new file, 133 lines)
@@ -0,0 +1,133 @@
import torch as th
from typing import Tuple

from .base_projection_layer import BaseProjectionLayer, mean_projection

from ..misc.norm import mahalanobis, frob_sq
from ..misc.distTools import get_mean_and_chol, get_cov, new_dist_like, has_diag_cov


class FrobeniusProjectionLayer(BaseProjectionLayer):

    def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):
        """
        Stolen from Fabian's Code (Public Version)

        Runs Frobenius projection layer and constructs cholesky of covariance

        Args:
            p: current distribution
            q: old distribution
            eps: (modified) kl bound / kl bound for mean part
            eps_cov: (modified) kl bound for cov part
            **kwargs:
        Returns: mean, cov cholesky
        """

        mean, chol = get_mean_and_chol(p, expand=True)
        old_mean, old_chol = get_mean_and_chol(q, expand=True)
        batch_shape = mean.shape[:-1]

        ################################################################################################################
        # precompute mean and cov part of frob projection, which are used for the projection.
        mean_part, cov_part, cov, cov_old = gaussian_frobenius(
            p, q, self.scale_prec, True)

        ################################################################################################################
        # mean projection maha/euclidean

        proj_mean = mean_projection(mean, old_mean, mean_part, eps)

        ################################################################################################################
        # cov projection frobenius

        cov_mask = cov_part > eps_cov

        if cov_mask.any():
            eta = th.ones(batch_shape, dtype=chol.dtype, device=chol.device)
            eta[cov_mask] = th.sqrt(cov_part[cov_mask] / eps_cov) - 1.
            eta = th.max(-eta, eta)

            new_cov = (cov + th.einsum('i,ijk->ijk', eta, cov_old)
                       ) / (1. + eta + 1e-16)[..., None, None]
            proj_chol = th.where(
                cov_mask[..., None, None], th.linalg.cholesky(new_cov), chol)
        else:
            proj_chol = chol

        if has_diag_cov(p):
            proj_chol = th.diagonal(proj_chol, dim1=-2, dim2=-1)

        proj_p = new_dist_like(p, proj_mean, proj_chol)
        return proj_p

    def trust_region_value(self, p, q):
        """
        Stolen from Fabian's Code (Public Version)

        Computes the Frobenius metric between two Gaussian distributions p and q.
        Args:
            p: current distribution
            q: old distribution
        Returns:
            mean and covariance part of Frobenius metric
        """
        return gaussian_frobenius(p, q, self.scale_prec)

    def get_trust_region_loss(self, p, proj_p):
        """
        Stolen from Fabian's Code (Public Version)
        """

        mean_diff, _ = self.trust_region_value(p, proj_p)
        # Dead branch kept from the original code: `policy` is undefined here,
        # but the `False and` short-circuit means it is never evaluated.
        if False and policy.contextual_std:
            # Compute MSE here, because we found the Frobenius norm tends to generate values that explode for the cov
            p_mean, proj_p_mean = p.mean, proj_p.mean
            cov_diff = (p_mean - proj_p_mean).pow(2).sum([-1, -2])
            delta_loss = (mean_diff + cov_diff).mean()
        else:
            delta_loss = mean_diff.mean()

        return delta_loss * self.trust_region_coeff


def gaussian_frobenius(p, q, scale_prec: bool = False, return_cov: bool = False):
    """
    Stolen from Fabian's Code (Public Version)

    Compute (p - q) (L_o L_o^T)^-1 (p - q)^T + |LL^T - L_o L_o^T|_F^2 with p, q ~ N(y, LL^T)
    Args:
        p: mean and chol of gaussian p
        q: mean and chol of gaussian q
        return_cov: return cov matrices for further computations
        scale_prec: scale objective with precision matrix
    Returns: mahalanobis distance, squared frobenius norm
    """

    mean, chol = get_mean_and_chol(p)
    mean_other, chol_other = get_mean_and_chol(q)

    if scale_prec:
        # maha objective for mean
        mean_part = mahalanobis(mean, mean_other, chol_other)
    else:
        # euclidean distance for mean
        # mean_part = ch.norm(mean_other - mean, ord=2, axis=1) ** 2
        mean_part = ((mean_other - mean) ** 2).sum(1)

    # frob objective for cov
    cov = get_cov(p)
    cov_other = get_cov(q)
    diff = cov_other - cov
    # Matrix is real symmetric PSD, therefore |A @ A^H|^2_F = tr{A @ A^H} = tr{A @ A}
    # cov_part = torch_batched_trace(diff @ diff)
    cov_part = frob_sq(diff, is_spd=True)

    if return_cov:
        return mean_part, cov_part, cov, cov_other

    return mean_part, cov_part
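The covariance step in _trust_region_projection rescales toward the old covariance with eta = sqrt(cov_part / eps_cov) - 1; since the squared Frobenius norm scales with 1/(1 + eta)^2, the projected covariance lands exactly on the bound. A self-contained numeric check of that arithmetic with plain torch (no distTools helpers):

import torch as th

cov = th.diag(th.tensor([4.0, 1.0]))[None]   # current covariance, batch of 1
cov_old = th.eye(2)[None]                    # old covariance
eps_cov = 1.0                                # Frobenius bound

diff = cov_old - cov
cov_part = (diff @ diff).diagonal(dim1=-2, dim2=-1).sum(-1)  # squared Frobenius norm: 9.0

eta = th.sqrt(cov_part / eps_cov) - 1.0                      # 2.0
new_cov = (cov + eta[..., None, None] * cov_old) / (1.0 + eta + 1e-16)[..., None, None]

new_diff = cov_old - new_cov
print((new_diff @ new_diff).diagonal(dim1=-2, dim2=-1).sum(-1))  # ~1.0 == eps_cov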
fancy_rl/projections/identity_projection_layer.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from .base_projection_layer import BaseProjectionLayer


class IdentityProjectionLayer(BaseProjectionLayer):
    def project_from_rollouts(self, dist, rollout_data, **kwargs):
        return dist, dist
fancy_rl/projections/kl_projection_layer.py (new file, 256 lines)
@@ -0,0 +1,256 @@
from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_cov, is_contextual, new_dist_like, has_diag_cov
from .base_projection_layer import BaseProjectionLayer, mean_projection, mean_equality_projection

import cpp_projection
import numpy as np
import torch as th
from typing import Tuple, Any

from ..misc.norm import mahalanobis

MAX_EVAL = 1000


class KLProjectionLayer(BaseProjectionLayer):
    """
    Stolen from Fabian's Code (Private Version)
    """

    def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):
        """
        Stolen from Fabian's Code (Private Version)

        Runs KL projection layer and constructs sqrt of covariance
        Args:
            p: current distribution
            q: old distribution
            eps: (modified) kl bound / kl bound for mean part
            eps_cov: (modified) kl bound for cov part
            **kwargs:

        Returns:
            mean, cov sqrt
        """
        mean, chol = get_mean_and_chol(p, expand=True)
        old_mean, old_chol = get_mean_and_chol(q, expand=True)

        ################################################################################################################
        # project mean with closed form
        # orig code: mean_part, _ = gaussian_kl(policy, p, q)
        # But the mean_part is just the mahalanobis dist:
        mean_part = mahalanobis(mean, old_mean, old_chol)
        if self.mean_eq:
            proj_mean = mean_equality_projection(
                mean, old_mean, mean_part, eps)
        else:
            proj_mean = mean_projection(mean, old_mean, mean_part, eps)

        if has_diag_cov(p):
            cov_diag = get_diag_cov_vec(p)
            old_cov_diag = get_diag_cov_vec(q)
            proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov_diag,
                                                                 old_cov_diag,
                                                                 eps_cov)
            proj_chol = proj_cov.sqrt()  # .diag_embed()
        else:
            cov = get_cov(p)
            old_cov = get_cov(q)
            proj_cov = KLProjectionGradFunctionCovOnly.apply(
                cov, old_cov, chol, old_chol, eps_cov)
            proj_chol = th.linalg.cholesky(proj_cov)
        proj_p = new_dist_like(p, proj_mean, proj_chol)
        return proj_p


class KLProjectionGradFunctionCovOnly(th.autograd.Function):
    projection_op = None

    @staticmethod
    def get_projection_op(batch_shape, dim, max_eval=MAX_EVAL):
        if not KLProjectionGradFunctionCovOnly.projection_op:
            KLProjectionGradFunctionCovOnly.projection_op = \
                cpp_projection.BatchedCovOnlyProjection(
                    batch_shape, dim, max_eval=max_eval)
        return KLProjectionGradFunctionCovOnly.projection_op

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        # std, old_std, eps_cov = args
        cov, old_cov, chol, old_chol, eps_cov = args

        batch_shape = chol.shape[0]
        dim = chol.shape[-1]

        cov_np = cov.cpu().detach().numpy()
        old_cov_np = old_cov.cpu().detach().numpy()
        chol_np = chol.cpu().detach().numpy()
        old_chol_np = old_chol.cpu().detach().numpy()
        # eps = eps_cov.cpu().detach().numpy().astype(old_std_np.dtype) * \
        eps = eps_cov * \
            np.ones(batch_shape, dtype=old_chol_np.dtype)

        p_op = KLProjectionGradFunctionCovOnly.get_projection_op(
            batch_shape, dim)
        ctx.proj = p_op

        proj_cov = p_op.forward(eps, old_chol_np, chol_np, cov_np)

        return th.Tensor(proj_cov)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        projection_op = ctx.proj
        d_std, = grad_outputs

        d_std_np = d_std.cpu().detach().numpy()
        d_std_np = np.atleast_2d(d_std_np)
        df_stds = projection_op.backward(d_std_np)
        df_stds = np.atleast_2d(df_stds)

        return d_std.new(df_stds), None, None, None, None


class KLProjectionGradFunctionDiagCovOnly(th.autograd.Function):
    projection_op = None

    @staticmethod
    def get_projection_op(batch_shape, dim: int, max_eval: int = MAX_EVAL):
        if not KLProjectionGradFunctionDiagCovOnly.projection_op:
            KLProjectionGradFunctionDiagCovOnly.projection_op = \
                cpp_projection.BatchedDiagCovOnlyProjection(
                    batch_shape, dim, max_eval=max_eval)
        return KLProjectionGradFunctionDiagCovOnly.projection_op

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        cov, old_std_np, eps_cov = args

        batch_shape = cov.shape[0]
        dim = cov.shape[-1]

        std_np = cov.to('cpu').detach().numpy()
        old_std_np = old_std_np.to('cpu').detach().numpy()
        # eps = eps_cov.to('cpu').detach().numpy().astype(old_std_np.dtype) * np.ones(batch_shape, dtype=old_std_np.dtype)
        eps = eps_cov * np.ones(batch_shape, dtype=old_std_np.dtype)

        p_op = KLProjectionGradFunctionDiagCovOnly.get_projection_op(
            batch_shape, dim)
        ctx.proj = p_op

        try:
            proj_std = p_op.forward(eps, old_std_np, std_np)
        except Exception:
            # fall back to the unprojected values if the C++ solver fails
            proj_std = std_np

        return cov.new(proj_std)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        projection_op = ctx.proj
        d_std, = grad_outputs

        d_std_np = d_std.to('cpu').detach().numpy()
        d_std_np = np.atleast_2d(d_std_np)
        df_stds = projection_op.backward(d_std_np)
        df_stds = np.atleast_2d(df_stds)

        return d_std.new(df_stds), None, None


class KLProjectionGradFunctionDiagSplit(th.autograd.Function):
    projection_op = None

    @staticmethod
    def get_projection_op(batch_shape, dim: int, max_eval: int = MAX_EVAL):
        if not KLProjectionGradFunctionDiagSplit.projection_op:
            KLProjectionGradFunctionDiagSplit.projection_op = \
                cpp_projection.BatchedSplitDiagMoreProjection(
                    batch_shape, dim, max_eval=max_eval)
        return KLProjectionGradFunctionDiagSplit.projection_op

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        mean, cov, old_mean, old_cov, eps_mu, eps_sigma = args

        batch_shape, dim = mean.shape

        mean_np = mean.detach().numpy()
        cov_np = cov.detach().numpy()
        old_mean = old_mean.detach().numpy()
        old_cov = old_cov.detach().numpy()
        eps_mu = eps_mu * np.ones(batch_shape)
        eps_sigma = eps_sigma * np.ones(batch_shape)

        # p_op = cpp_projection.BatchedSplitDiagMoreProjection(batch_shape, dim, max_eval=100)
        p_op = KLProjectionGradFunctionDiagSplit.get_projection_op(
            batch_shape, dim)

        try:
            proj_mean, proj_cov = p_op.forward(
                eps_mu, eps_sigma, old_mean, old_cov, mean_np, cov_np)
        except Exception:
            # try a second time
            proj_mean, proj_cov = p_op.forward(
                eps_mu, eps_sigma, old_mean, old_cov, mean_np, cov_np)
        ctx.proj = p_op

        return mean.new(proj_mean), cov.new(proj_cov)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        p_op = ctx.proj
        d_means, d_std = grad_outputs

        d_std_np = d_std.detach().numpy()
        d_std_np = np.atleast_2d(d_std_np)
        d_mean_np = d_means.detach().numpy()
        dtarget_means, dtarget_covs = p_op.backward(d_mean_np, d_std_np)
        dtarget_covs = np.atleast_2d(dtarget_covs)

        return d_means.new(dtarget_means), d_std.new(dtarget_covs), None, None, None, None


class KLProjectionGradFunctionJoint(th.autograd.Function):
    projection_op = None

    @staticmethod
    def get_projection_op(batch_shape, dim: int, max_eval: int = MAX_EVAL):
        if not KLProjectionGradFunctionJoint.projection_op:
            KLProjectionGradFunctionJoint.projection_op = \
                cpp_projection.BatchedProjection(batch_shape, dim, eec=False, constrain_entropy=False,
                                                 max_eval=max_eval)
        return KLProjectionGradFunctionJoint.projection_op

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        mean, cov, old_mean, old_cov, eps, beta = args

        batch_shape, dim = mean.shape

        mean_np = mean.detach().numpy()
        cov_np = cov.detach().numpy()
        old_mean = old_mean.detach().numpy()
        old_cov = old_cov.detach().numpy()
        eps = eps * np.ones(batch_shape)
        beta = beta.detach().numpy() * np.ones(batch_shape)

        # projection_op = cpp_projection.BatchedProjection(batch_shape, dim, eec=False, constrain_entropy=False)
        # ctx.proj = projection_op

        p_op = KLProjectionGradFunctionJoint.get_projection_op(
            batch_shape, dim)
        ctx.proj = p_op

        proj_mean, proj_cov = p_op.forward(
            eps, beta, old_mean, old_cov, mean_np, cov_np)

        return mean.new(proj_mean), cov.new(proj_cov)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        projection_op = ctx.proj
        d_means, d_covs = grad_outputs
        df_means, df_covs = projection_op.backward(
            d_means.detach().numpy(), d_covs.detach().numpy())
        return d_means.new(df_means), d_means.new(df_covs), None, None, None, None
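The autograd.Function subclasses route the projection through ITPAL's C++ solvers and return hand-written gradients from their backward passes. A hedged usage sketch for the diagonal case, which only replays the call pattern already used in _trust_region_projection above (it runs only with cpp_projection installed):

import torch as th
from fancy_rl.projections.kl_projection_layer import KLProjectionGradFunctionDiagCovOnly

cov_diag = th.tensor([[2.0, 0.5]], requires_grad=True)  # current diagonal covariance
old_cov_diag = th.tensor([[1.0, 1.0]])                  # old diagonal covariance
eps_cov = 1e-3                                          # covariance KL bound

proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov_diag, old_cov_diag, eps_cov)
proj_cov.sum().backward()  # gradients flow through the C++ projection_op.backward()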
fancy_rl/projections/w2_projection_layer.py (new file, 164 lines)
@@ -0,0 +1,164 @@
import numpy as np
import torch as th
from typing import Tuple, Any

from .base_projection_layer import BaseProjectionLayer, mean_projection

from ..misc.norm import mahalanobis, _batch_trace
from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, has_diag_cov

from stable_baselines3.common.distributions import Distribution


class WassersteinProjectionLayer(BaseProjectionLayer):
    """
    Stolen from Fabian's Code (Public Version)
    """

    def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):
        """
        Runs commutative Wasserstein projection layer and constructs sqrt of covariance
        Args:
            p: current distribution
            q: old distribution
            eps: (modified) kl bound / kl bound for mean part
            eps_cov: (modified) kl bound for cov part
            **kwargs:

        Returns:
            mean, cov sqrt
        """

        mean, sqrt = get_mean_and_sqrt(p, expand=True)
        old_mean, old_sqrt = get_mean_and_sqrt(q, expand=True)
        batch_shape = mean.shape[:-1]

        ####################################################################################################################
        # precompute mean and cov part of W2, which are used for the projection.
        # Both parts differ based on precision scaling.
        # If activated, the mean part is the maha distance and the cov has a more complex term in the inner parenthesis.
        mean_part, cov_part = gaussian_wasserstein_commutative(
            p, q, self.scale_prec)

        ####################################################################################################################
        # project mean (w/ or w/o precision scaling)
        proj_mean = mean_projection(mean, old_mean, mean_part, eps)

        ####################################################################################################################
        # project covariance (w/ or w/o precision scaling)

        cov_mask = cov_part > eps_cov

        if cov_mask.any():
            # gradient issue with ch.where, it executes both paths and gives NaN gradient.
            eta = th.ones(batch_shape, dtype=sqrt.dtype, device=sqrt.device)
            eta[cov_mask] = th.sqrt(cov_part[cov_mask] / eps_cov) - 1.
            eta = th.max(-eta, eta)

            new_sqrt = (sqrt + th.einsum('i,ijk->ijk', eta, old_sqrt)
                        ) / (1. + eta + 1e-16)[..., None, None]
            proj_sqrt = th.where(cov_mask[..., None, None], new_sqrt, sqrt)
        else:
            proj_sqrt = sqrt

        if has_diag_cov(p):
            proj_sqrt = th.diagonal(proj_sqrt, dim1=-2, dim2=-1)

        proj_p = self.new_dist_like(p, proj_mean, proj_sqrt)
        return proj_p

    def trust_region_value(self, p, q):
        """
        Computes the Wasserstein distance between two Gaussian distributions p and q.
        Args:
            p: current distribution
            q: old distribution
        Returns:
            mean and covariance part of Wasserstein distance
        """
        mean_part, cov_part = gaussian_wasserstein_commutative(
            p, q, scale_prec=self.scale_prec)
        return mean_part + cov_part

    def get_trust_region_loss(self, p, proj_p):
        # p: predicted distribution from network output
        # proj_p: projected distribution

        proj_mean, proj_sqrt = get_mean_and_sqrt(proj_p)
        p_target = self.new_dist_like(p, proj_mean, proj_sqrt)
        kl_diff = self.trust_region_value(p, p_target)

        kl_loss = kl_diff.mean()

        return kl_loss * self.trust_region_coeff

    def new_dist_like(self, orig_p, mean, cov_sqrt):
        assert isinstance(orig_p, Distribution)
        p = orig_p.distribution
        if isinstance(p, th.distributions.Normal):
            p_out = orig_p.__class__(orig_p.action_dim)
            p_out.distribution = th.distributions.Normal(mean, cov_sqrt)
        elif isinstance(p, th.distributions.Independent):
            p_out = orig_p.__class__(orig_p.action_dim)
            p_out.distribution = th.distributions.Independent(
                th.distributions.Normal(mean, cov_sqrt), 1)
        elif isinstance(p, th.distributions.MultivariateNormal):
            p_out = orig_p.__class__(orig_p.action_dim)
            p_out.distribution = th.distributions.MultivariateNormal(
                mean, scale_tril=cov_sqrt, validate_args=False)
        else:
            raise Exception('Dist-Type not implemented (of sb3 dist)')
        p_out.cov_sqrt = cov_sqrt
        return p_out


def gaussian_wasserstein_commutative(p, q, scale_prec=False) -> Tuple[th.Tensor, th.Tensor]:
    """
    Compute mean part and cov part of W_2(p || q) with p, q ~ N(y, SS).
    This version DOES assume commutativity of both distributions, i.e. covariance matrices.
    This is less general and assumes both distributions are somewhat close together.
    When scale_prec is true scale both distributions with old precision matrix.
    Args:
        p: mean and sqrt of gaussian p
        q: mean and sqrt of gaussian q
        scale_prec: scale objective by old precision matrix.
                    This penalizes directions based on old uncertainty/covariance.
    Returns: mean part of W2, cov part of W2
    """
    mean, sqrt = get_mean_and_sqrt(p, expand=True)
    mean_other, sqrt_other = get_mean_and_sqrt(q, expand=True)

    if scale_prec:
        # maha objective for mean
        mean_part = mahalanobis(mean, mean_other, sqrt_other)
    else:
        # euclidean distance for mean
        # mean_part = ch.norm(mean_other - mean, ord=2, axis=1) ** 2
        mean_part = ((mean_other - mean) ** 2).sum(1)

    cov = get_cov(p)
    if scale_prec and False:
        # Dead branch kept from the original code: disabled by `and False`.
        # cov constraint scaled with precision of old dist
        batch_dim, dim = mean.shape

        identity = th.eye(dim, dtype=sqrt.dtype, device=sqrt.device)
        sqrt_inv_other = th.linalg.solve(sqrt_other, identity)
        c = sqrt_inv_other @ cov @ sqrt_inv_other

        cov_part = _batch_trace(
            identity + c - 2 * sqrt_inv_other @ sqrt)

    else:
        # W2 objective for cov assuming normal W2 objective for mean
        cov_other = get_cov(q)
        cov_part = _batch_trace(
            cov_other + cov - 2 * th.bmm(sqrt_other, sqrt))

    return mean_part, cov_part
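For commuting covariances, the cov part computed above reduces to tr(C_q + C_p - 2 S_q S_p) = |S_p - S_q|_F^2, with S the matrix square root. A self-contained check of that identity with plain torch:

import torch as th

sqrt_p = th.diag(th.tensor([2.0, 1.0]))[None]  # S_p, so C_p = diag(4, 1); batch of 1
sqrt_q = th.eye(2)[None]                       # S_q, so C_q = I

cov_p = sqrt_p @ sqrt_p
cov_q = sqrt_q @ sqrt_q
cov_part = (cov_q + cov_p - 2 * th.bmm(sqrt_q, sqrt_p)).diagonal(dim1=-2, dim2=-1).sum(-1)
print(cov_part)  # tensor([1.]): equals (2-1)^2 + (1-1)^2, the squared sqrt difference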