metastable-projections/projections/kl_projection_layer.py

from typing import Any

import cpp_projection
import numpy as np
import torch as th

from .base_projection_layer import BaseProjectionLayer, mean_projection, mean_equality_projection
from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_cov, is_contextual, new_dist_like, has_diag_cov
from ..misc.norm import mahalanobis

# Maximum number of evaluations handed to the C++ projection optimizers.
MAX_EVAL = 1000

class KLProjectionLayer(BaseProjectionLayer):
    """
    KL trust-region projection for Gaussian policies.

    Stolen from Fabian's Code (Private Version)
    """

    def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):
        """
        Runs the KL projection layer and constructs the square root of the
        covariance.

        Stolen from Fabian's Code (Private Version)

        Args:
            p: current distribution
            q: old distribution
            eps: (modified) KL bound / KL bound for the mean part
            eps_cov: (modified) KL bound for the covariance part
            **kwargs: unused here (the original signature also took a policy instance)

        Returns:
            the projected distribution, built from the projected mean and
            covariance square root
        """
        mean, chol = get_mean_and_chol(p, expand=True)
        old_mean, old_chol = get_mean_and_chol(q, expand=True)

        ################################################################
        # Project the mean in closed form.
        # Original code: mean_part, _ = gaussian_kl(policy, p, q)
        # The mean part of the Gaussian KL is just the Mahalanobis distance:
        mean_part = mahalanobis(mean, old_mean, old_chol)
        if self.mean_eq:
            proj_mean = mean_equality_projection(mean, old_mean, mean_part, eps)
        else:
            proj_mean = mean_projection(mean, old_mean, mean_part, eps)

        # Project the covariance with the numerical C++ optimizer.
        if has_diag_cov(p):
            cov_diag = get_diag_cov_vec(p)
            old_cov_diag = get_diag_cov_vec(q)
            proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(
                cov_diag, old_cov_diag, eps_cov)
            # Elementwise sqrt of the variances; .diag_embed() would give
            # the full Cholesky factor.
            proj_chol = proj_cov.sqrt()
        else:
            cov = get_cov(p)
            old_cov = get_cov(q)
            proj_cov = KLProjectionGradFunctionCovOnly.apply(
                cov, old_cov, chol, old_chol, eps_cov)
            proj_chol = th.linalg.cholesky(proj_cov)

        proj_p = new_dist_like(p, proj_mean, proj_chol)
        return proj_p
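
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library). A minimal, hedged
# example of invoking the layer, assuming the distTools helpers accept
# torch.distributions.MultivariateNormal and that KLProjectionLayer can be
# constructed with defaults (the real constructor lives in
# BaseProjectionLayer).
def _example_kl_projection_layer():
    from torch.distributions import MultivariateNormal

    batch, dim = 16, 4
    scale_tril = th.eye(dim).expand(batch, dim, dim)
    p = MultivariateNormal(th.randn(batch, dim), scale_tril=scale_tril)  # current policy
    q = MultivariateNormal(th.randn(batch, dim), scale_tril=scale_tril)  # old policy

    layer = KLProjectionLayer()  # hypothetical default construction
    # eps bounds the mean part of the KL, eps_cov the covariance part.
    return layer._trust_region_projection(
        p, q, eps=th.tensor(0.01), eps_cov=th.tensor(0.001))
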
class KLProjectionGradFunctionCovOnly(th.autograd.Function):
    projection_op = None

    @staticmethod
    def get_projection_op(batch_shape, dim, max_eval=MAX_EVAL):
        # Lazily build and cache the batched C++ projection op.
        if not KLProjectionGradFunctionCovOnly.projection_op:
            KLProjectionGradFunctionCovOnly.projection_op = \
                cpp_projection.BatchedCovOnlyProjection(
                    batch_shape, dim, max_eval=max_eval)
        return KLProjectionGradFunctionCovOnly.projection_op

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        cov, old_cov, chol, old_chol, eps_cov = args
        batch_shape = chol.shape[0]
        dim = chol.shape[-1]

        cov_np = cov.cpu().detach().numpy()
        old_cov_np = old_cov.cpu().detach().numpy()
        chol_np = chol.cpu().detach().numpy()
        old_chol_np = old_chol.cpu().detach().numpy()

        # eps_cov may arrive as a scalar tensor or a plain float; either way
        # broadcast it to one bound per batch element.
        if isinstance(eps_cov, th.Tensor):
            eps_cov = eps_cov.cpu().detach().numpy()
        eps = eps_cov * np.ones(batch_shape, dtype=old_chol_np.dtype)

        p_op = KLProjectionGradFunctionCovOnly.get_projection_op(batch_shape, dim)
        ctx.proj = p_op

        proj_cov = p_op.forward(eps, old_chol_np, chol_np, cov_np)
        # Rebuild on the input's dtype/device; a bare th.Tensor(...) would
        # silently yield a CPU float32 tensor.
        return cov.new(proj_cov)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        projection_op = ctx.proj
        d_cov, = grad_outputs
        d_cov_np = d_cov.cpu().detach().numpy()
        d_cov_np = np.atleast_2d(d_cov_np)
        df_covs = projection_op.backward(d_cov_np)
        df_covs = np.atleast_2d(df_covs)
        # Gradient only w.r.t. the first input (cov); the rest get None.
        return d_cov.new(df_covs), None, None, None, None
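
# A minimal sketch of calling the full-covariance projection directly and
# backpropagating through it. Assumes the compiled cpp_projection extension
# is importable; shapes follow forward() above.
def _example_cov_only_apply():
    batch, dim = 8, 3
    a, old_a = th.randn(batch, dim, dim), th.randn(batch, dim, dim)
    # Build SPD matrices so the Cholesky factors exist.
    cov = (a @ a.transpose(-1, -2) + th.eye(dim)).requires_grad_(True)
    old_cov = old_a @ old_a.transpose(-1, -2) + th.eye(dim)
    chol = th.linalg.cholesky(cov.detach())
    old_chol = th.linalg.cholesky(old_cov)

    proj_cov = KLProjectionGradFunctionCovOnly.apply(
        cov, old_cov, chol, old_chol, 0.001)
    proj_cov.sum().backward()  # gradient flows through the C++ backward pass
    return cov.grad
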
class KLProjectionGradFunctionDiagCovOnly(th.autograd.Function):
    projection_op = None

    @staticmethod
    def get_projection_op(batch_shape, dim: int, max_eval: int = MAX_EVAL):
        # Lazily build and cache the batched C++ projection op.
        if not KLProjectionGradFunctionDiagCovOnly.projection_op:
            KLProjectionGradFunctionDiagCovOnly.projection_op = \
                cpp_projection.BatchedDiagCovOnlyProjection(
                    batch_shape, dim, max_eval=max_eval)
        return KLProjectionGradFunctionDiagCovOnly.projection_op

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        # The inputs are vectors of diagonal covariance entries (variances);
        # the caller takes the elementwise sqrt of the result.
        cov, old_cov, eps_cov = args
        batch_shape = cov.shape[0]
        dim = cov.shape[-1]

        cov_np = cov.to('cpu').detach().numpy()
        old_cov_np = old_cov.to('cpu').detach().numpy()

        # Broadcast the bound to one entry per batch element.
        if isinstance(eps_cov, th.Tensor):
            eps_cov = eps_cov.cpu().detach().numpy()
        eps = eps_cov * np.ones(batch_shape, dtype=old_cov_np.dtype)

        p_op = KLProjectionGradFunctionDiagCovOnly.get_projection_op(batch_shape, dim)
        ctx.proj = p_op

        proj_cov = p_op.forward(eps, old_cov_np, cov_np)
        return cov.new(proj_cov)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        projection_op = ctx.proj
        d_cov, = grad_outputs
        d_cov_np = d_cov.to('cpu').detach().numpy()
        d_cov_np = np.atleast_2d(d_cov_np)
        df_covs = projection_op.backward(d_cov_np)
        df_covs = np.atleast_2d(df_covs)
        # Gradient only w.r.t. the first input (cov).
        return d_cov.new(df_covs), None, None
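
# Sketch for the diagonal case: the inputs are per-dimension variances, as in
# _trust_region_projection above, which takes the elementwise sqrt afterwards.
def _example_diag_cov_only_apply():
    batch, dim = 8, 3
    var = th.rand(batch, dim).add_(0.1).requires_grad_(True)  # current variances
    old_var = th.rand(batch, dim).add_(0.1)                   # old variances
    proj_var = KLProjectionGradFunctionDiagCovOnly.apply(var, old_var, 0.001)
    proj_std = proj_var.sqrt()  # what the layer feeds into new_dist_like
    proj_var.sum().backward()
    return proj_std, var.grad
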
class KLProjectionGradFunctionDiagSplit(th.autograd.Function):
    projection_op = None

    @staticmethod
    def get_projection_op(batch_shape, dim: int, max_eval: int = MAX_EVAL):
        # Lazily build and cache the batched C++ projection op.
        if not KLProjectionGradFunctionDiagSplit.projection_op:
            KLProjectionGradFunctionDiagSplit.projection_op = \
                cpp_projection.BatchedSplitDiagMoreProjection(
                    batch_shape, dim, max_eval=max_eval)
        return KLProjectionGradFunctionDiagSplit.projection_op

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        mean, cov, old_mean, old_cov, eps_mu, eps_sigma = args
        batch_shape, dim = mean.shape

        mean_np = mean.detach().numpy()
        cov_np = cov.detach().numpy()
        old_mean_np = old_mean.detach().numpy()
        old_cov_np = old_cov.detach().numpy()

        eps_mu = eps_mu * np.ones(batch_shape)
        eps_sigma = eps_sigma * np.ones(batch_shape)

        p_op = KLProjectionGradFunctionDiagSplit.get_projection_op(batch_shape, dim)

        try:
            proj_mean, proj_cov = p_op.forward(
                eps_mu, eps_sigma, old_mean_np, old_cov_np, mean_np, cov_np)
        except Exception:
            # Retry once; presumably works around a transient solver failure.
            proj_mean, proj_cov = p_op.forward(
                eps_mu, eps_sigma, old_mean_np, old_cov_np, mean_np, cov_np)

        ctx.proj = p_op
        return mean.new(proj_mean), cov.new(proj_cov)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        p_op = ctx.proj
        d_means, d_covs = grad_outputs

        d_cov_np = np.atleast_2d(d_covs.detach().numpy())
        d_mean_np = d_means.detach().numpy()

        dtarget_means, dtarget_covs = p_op.backward(d_mean_np, d_cov_np)
        dtarget_covs = np.atleast_2d(dtarget_covs)
        # Gradients w.r.t. mean and cov only; the bounds get None.
        return d_means.new(dtarget_means), d_covs.new(dtarget_covs), None, None, None, None
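
# Sketch for the split variant: separate bounds for the mean (eps_mu) and the
# diagonal covariance (eps_sigma). CPU tensors only, mirroring forward() above.
def _example_diag_split_apply():
    batch, dim = 8, 3
    mean = th.randn(batch, dim, requires_grad=True)
    var = th.rand(batch, dim).add_(0.1).requires_grad_(True)
    old_mean = th.randn(batch, dim)
    old_var = th.rand(batch, dim).add_(0.1)
    proj_mean, proj_var = KLProjectionGradFunctionDiagSplit.apply(
        mean, var, old_mean, old_var, 0.01, 0.001)
    (proj_mean.sum() + proj_var.sum()).backward()
    return mean.grad, var.grad
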
class KLProjectionGradFunctionJoint(th.autograd.Function):
    projection_op = None

    @staticmethod
    def get_projection_op(batch_shape, dim: int, max_eval: int = MAX_EVAL):
        # Lazily build and cache the batched C++ projection op.
        if not KLProjectionGradFunctionJoint.projection_op:
            KLProjectionGradFunctionJoint.projection_op = \
                cpp_projection.BatchedProjection(
                    batch_shape, dim, eec=False, constrain_entropy=False,
                    max_eval=max_eval)
        return KLProjectionGradFunctionJoint.projection_op

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        mean, cov, old_mean, old_cov, eps, beta = args
        batch_shape, dim = mean.shape

        mean_np = mean.detach().numpy()
        cov_np = cov.detach().numpy()
        old_mean_np = old_mean.detach().numpy()
        old_cov_np = old_cov.detach().numpy()

        eps = eps * np.ones(batch_shape)
        beta = beta.detach().numpy() * np.ones(batch_shape)

        p_op = KLProjectionGradFunctionJoint.get_projection_op(batch_shape, dim)
        ctx.proj = p_op

        proj_mean, proj_cov = p_op.forward(
            eps, beta, old_mean_np, old_cov_np, mean_np, cov_np)
        return mean.new(proj_mean), cov.new(proj_cov)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        projection_op = ctx.proj
        d_means, d_covs = grad_outputs
        df_means, df_covs = projection_op.backward(
            d_means.detach().numpy(), d_covs.detach().numpy())
        # Gradients w.r.t. mean and cov only; the bounds get None.
        return d_means.new(df_means), d_covs.new(df_covs), None, None, None, None
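
# Sketch for the joint variant: one shared KL bound eps, plus beta, which is
# forwarded to cpp_projection's BatchedProjection. With constrain_entropy=False
# above, beta is presumably ignored (an assumption; its semantics live in the
# C++ code), so a dummy tensor is passed. Note forward() expects beta as a tensor.
def _example_joint_apply():
    batch, dim = 8, 3
    mean = th.randn(batch, dim, requires_grad=True)
    a, old_a = th.randn(batch, dim, dim), th.randn(batch, dim, dim)
    cov = (a @ a.transpose(-1, -2) + th.eye(dim)).requires_grad_(True)
    old_mean = th.randn(batch, dim)
    old_cov = old_a @ old_a.transpose(-1, -2) + th.eye(dim)
    proj_mean, proj_cov = KLProjectionGradFunctionJoint.apply(
        mean, cov, old_mean, old_cov, 0.01, th.tensor(-1000.0))  # dummy beta
    (proj_mean.sum() + proj_cov.sum()).backward()
    return mean.grad, cov.grad
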