# metastable-baselines/projections_orig/kl_projection_layer.py
# 2022-06-16 10:59:26 +02:00
#
# 102 lines
# 3.7 KiB
# Python

import cpp_projection
import numpy as np
import torch as ch
from typing import Any, Tuple
from trust_region_projections.models.policy.abstract_gaussian_policy import AbstractGaussianPolicy
from trust_region_projections.projections.base_projection_layer import BaseProjectionLayer, mean_projection
from trust_region_projections.utils.projection_utils import gaussian_kl
from trust_region_projections.utils.torch_utils import get_numpy
class KLProjectionLayer(BaseProjectionLayer):
    """
    Projection layer that keeps the current Gaussian policy inside a KL trust region around the old one.

    The mean is projected in closed form via ``mean_projection``; the diagonal covariance is projected
    numerically by ``KLProjectionGradFunctionDiagCovOnly``. Full covariance matrices are not supported.
    """

    def _trust_region_projection(self, policy: AbstractGaussianPolicy, p: Tuple[ch.Tensor, ch.Tensor],
                                 q: Tuple[ch.Tensor, ch.Tensor], eps: ch.Tensor, eps_cov: ch.Tensor, **kwargs):
        """
        Project the current distribution ``p`` onto the KL trust region around the old distribution ``q``.

        Args:
            policy: policy instance; its ``contextual_std``, ``is_diag`` and ``covariance`` are used
            p: (mean, std) of the current distribution
            q: (mean, std) of the old distribution
            eps: (modified) kl bound / kl bound for the mean part
            eps_cov: (modified) kl bound for the covariance part
            **kwargs: not used here
        Returns:
            projected mean, and projected covariance cholesky (a diagonal matrix whose entries are the
            square roots of the projected variances)
        Raises:
            NotImplementedError: if the policy does not use a diagonal covariance
        """
        mean, std = p
        old_mean, old_std = q
        shared_std = not policy.contextual_std

        if shared_std:
            # A non-contextual std is presumably identical for every batch entry, so only the first row
            # is projected to keep the number of numerical optimizations down.
            std, old_std = std[:1], old_std[:1]

        # Mean: closed-form projection driven by the mean part of the KL divergence.
        kl_mean_part, _ = gaussian_kl(policy, p, q)
        proj_mean = mean_projection(mean, old_mean, kl_mean_part, eps)

        cov = policy.covariance(std)
        old_cov = policy.covariance(old_std)

        if not policy.is_diag:
            raise NotImplementedError("The KL projection currently does not support full covariance matrices.")

        # Covariance: numerical projection on the variances (the covariance diagonals).
        variances = cov.diagonal(dim1=-2, dim2=-1)
        old_variances = old_cov.diagonal(dim1=-2, dim2=-1)
        proj_variances = KLProjectionGradFunctionDiagCovOnly.apply(variances, old_variances, eps_cov)
        proj_std = proj_variances.sqrt().diag_embed()

        if shared_std:
            # Broadcast the single projected std back to the full batch size.
            proj_std = proj_std.expand(mean.shape[0], -1, -1)

        return proj_mean, proj_std
class KLProjectionGradFunctionDiagCovOnly(ch.autograd.Function):
    """
    Autograd function wrapping ``cpp_projection.BatchedDiagCovOnlyProjection``.

    ``forward`` runs the C++ op's ``forward`` on numpy copies of the inputs, and ``backward`` runs the
    ``backward`` of the same op instance on the incoming gradient. In this file the caller passes covariance
    diagonals (variances) of the current and old policy, so the projected quantities are variances.
    """

    # Cached C++ projection op, reused across calls to avoid constructing a new one every step.
    projection_op = None
    # The (batch_shape, dim, max_eval) configuration the cached op was constructed with.
    projection_op_key = None

    @staticmethod
    def get_projection_op(batch_shape, dim, max_eval=100):
        """
        Return a cached projection op constructed for exactly ``(batch_shape, dim, max_eval)``.

        Args:
            batch_shape: number of rows (distributions) to project
            dim: size of the last axis of the inputs
            max_eval: forwarded to the C++ constructor as ``max_eval``
        Returns:
            a ``cpp_projection.BatchedDiagCovOnlyProjection`` matching the requested configuration
        """
        cls = KLProjectionGradFunctionDiagCovOnly
        key = (batch_shape, dim, max_eval)
        # The op is constructed for a fixed batch size and dimension. Returning the first cached op for
        # every later call (regardless of shape) would reuse an op built for the wrong batch size, e.g.
        # with a contextual std and varying batch sizes, so rebuild whenever the configuration changes.
        if cls.projection_op is None or cls.projection_op_key != key:
            # Build into a local first so a failing constructor leaves the cache consistent.
            new_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=max_eval)
            cls.projection_op = new_op
            cls.projection_op_key = key
        return cls.projection_op

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        """
        Project the first input towards the second one under the KL bound ``eps_cov``.

        Args (positional):
            var: tensor to project; presumably (batch, dim) — batch is read from axis 0, dim from axis -1
            old_var: reference tensor, passed to the C++ op alongside ``var``
            eps_cov: KL bound, broadcast to one value per batch entry
        Returns:
            projected values as a tensor created via ``var.new`` (same dtype as ``var``)
        """
        var, old_var, eps_cov = args

        batch_shape = var.shape[0]
        dim = var.shape[-1]

        var_np = get_numpy(var)
        old_var_np = get_numpy(old_var)
        # One bound per batch entry for the batched C++ op.
        eps = get_numpy(eps_cov) * np.ones(batch_shape)

        p_op = KLProjectionGradFunctionDiagCovOnly.get_projection_op(batch_shape, dim)
        # Keep a reference to the exact op instance used here; backward calls this same instance
        # (presumably because it holds intermediate results from forward).
        ctx.proj = p_op

        proj_var = p_op.forward(eps, old_var_np, var_np)

        return var.new(proj_var)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        """
        Back-propagate through the C++ projection.

        Returns:
            gradient w.r.t. the projected input, and ``None`` for ``old_var`` and ``eps_cov``
            (no gradient flows to them)
        """
        projection_op = ctx.proj
        d_out, = grad_outputs

        # The C++ op presumably expects 2-D (batch, dim) arrays; atleast_2d guards against 1-D data.
        d_out_np = np.atleast_2d(get_numpy(d_out))
        d_in_np = np.atleast_2d(projection_op.backward(d_out_np))

        return d_out.new(d_in_np), None, None