102 lines
3.7 KiB
Python
102 lines
3.7 KiB
Python
|
import cpp_projection
|
||
|
import numpy as np
|
||
|
import torch as ch
|
||
|
from typing import Any, Tuple
|
||
|
|
||
|
from trust_region_projections.models.policy.abstract_gaussian_policy import AbstractGaussianPolicy
|
||
|
from trust_region_projections.projections.base_projection_layer import BaseProjectionLayer, mean_projection
|
||
|
from trust_region_projections.utils.projection_utils import gaussian_kl
|
||
|
from trust_region_projections.utils.torch_utils import get_numpy
|
||
|
|
||
|
|
||
|
class KLProjectionLayer(BaseProjectionLayer):
    """Projection layer that applies a KL-based trust region projection to a Gaussian policy."""

    def _trust_region_projection(self, policy: AbstractGaussianPolicy, p: Tuple[ch.Tensor, ch.Tensor],
                                 q: Tuple[ch.Tensor, ch.Tensor], eps: ch.Tensor, eps_cov: ch.Tensor, **kwargs):
        """
        Project the mean in closed form and the diagonal covariance via the KL projection op.

        Args:
            policy: policy instance; its ``contextual_std``, ``is_diag`` and ``covariance`` members are used
            p: current distribution as a (mean, std) tuple
            q: old distribution as a (mean, std) tuple
            eps: (modified) kl bound/ kl bound for mean part
            eps_cov: (modified) kl bound for cov part
            **kwargs: accepted but not used here

        Returns:
            projected mean, projected cov cholesky

        Raises:
            NotImplementedError: if ``policy.is_diag`` is falsy (full covariances are unsupported).
        """
        new_mean, new_std = p
        prev_mean, prev_std = q

        if not policy.contextual_std:
            # The std is not contextual, so projecting a single entry is enough;
            # this reduces the number of numerical optimizations.
            new_std, prev_std = new_std[:1], prev_std[:1]

        # ------------------------------------------------------------------
        # Mean: closed-form projection driven by the mean part of the KL.
        # Note that the KL is evaluated on the unsliced distributions p and q.
        mean_kl, _ = gaussian_kl(policy, p, q)
        projected_mean = mean_projection(new_mean, prev_mean, mean_kl, eps)

        # ------------------------------------------------------------------
        # Covariance: handled by the numerical projection op.
        new_cov = policy.covariance(new_std)
        prev_cov = policy.covariance(prev_std)

        if not policy.is_diag:
            raise NotImplementedError("The KL projection currently does not support full covariance matrices.")

        new_var = new_cov.diagonal(dim1=-2, dim2=-1)
        prev_var = prev_cov.diagonal(dim1=-2, dim2=-1)
        projected_var = KLProjectionGradFunctionDiagCovOnly.apply(new_var, prev_var, eps_cov)
        # Element-wise sqrt of the projected diagonal, embedded back into a matrix.
        projected_std = ch.diag_embed(ch.sqrt(projected_var))

        if policy.contextual_std:
            return projected_mean, projected_std

        # Only the first entry was projected above; broadcast it back to the batch size of the mean.
        batch_size = new_mean.shape[0]
        return projected_mean, projected_std.expand(batch_size, -1, -1)
|
||
|
|
||
|
|
||
|
class KLProjectionGradFunctionDiagCovOnly(ch.autograd.Function):
    """
    Autograd function wrapping ``cpp_projection.BatchedDiagCovOnlyProjection``.

    ``forward`` hands the inputs to the external projection operator as numpy arrays and
    ``backward`` asks the same operator instance for the gradient, so the operator used in
    ``forward`` is stored on ``ctx``.

    The operator is cached on the class. It is rebuilt whenever the requested
    ``(batch_shape, dim, max_eval)`` differs from the configuration the cached operator was
    built with, because the operator is constructed for one specific configuration.
    """

    # Most recently built projection operator (None until the first request).
    projection_op = None
    # (batch_shape, dim, max_eval) that ``projection_op`` was constructed with.
    _projection_op_key = None

    @staticmethod
    def get_projection_op(batch_shape, dim, max_eval=100):
        """
        Return a projection operator matching the requested configuration.

        Args:
            batch_shape: batch size the operator is constructed with
            dim: dimensionality the operator is constructed with
            max_eval: forwarded to the operator constructor as ``max_eval``

        Returns:
            The cached operator if it was built with the same arguments, otherwise a newly
            constructed ``cpp_projection.BatchedDiagCovOnlyProjection`` (which then replaces
            the cached one).
        """
        cls = KLProjectionGradFunctionDiagCovOnly
        key = (batch_shape, dim, max_eval)
        # Compare against None explicitly instead of relying on the truthiness of the
        # extension object, and never hand out an operator built for another configuration.
        if cls.projection_op is None or cls._projection_op_key != key:
            cls.projection_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=max_eval)
            cls._projection_op_key = key
        return cls.projection_op

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        """
        Run the projection operator on the given inputs.

        Args:
            ctx: autograd context; the operator used is stored on it as ``ctx.proj``
            *args: exactly three positional values ``(std, old_std, eps_cov)``.
                NOTE(review): despite the names, the caller in this file passes the diagonals
                of the covariance matrices - confirm before relying on the naming.
            **kwargs: unused

        Returns:
            Tensor created via ``std.new`` from the operator's output.
        """
        std, old_std, eps_cov = args

        # First axis is used as the batch size, last axis as the dimensionality.
        batch_shape = std.shape[0]
        dim = std.shape[-1]

        cov_np = get_numpy(std)
        old_cov_np = get_numpy(old_std)
        # The operator receives one bound per batch entry.
        eps = get_numpy(eps_cov) * np.ones(batch_shape)

        p_op = KLProjectionGradFunctionDiagCovOnly.get_projection_op(batch_shape, dim)
        # backward() must query the very same operator instance that ran forward().
        ctx.proj = p_op

        proj_std = p_op.forward(eps, old_cov_np, cov_np)

        return std.new(proj_std)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        """
        Obtain the gradient for the first forward input from the projection operator.

        Args:
            ctx: autograd context holding the operator from ``forward`` as ``ctx.proj``
            *grad_outputs: a single gradient with respect to the forward output

        Returns:
            Tuple ``(grad, None, None)``: a gradient for the first forward input only; the
            second and third forward inputs receive no gradient.
        """
        projection_op = ctx.proj
        d_std, = grad_outputs

        # The operator is always handed (and its result treated as) an at-least-2-D array.
        d_std_np = np.atleast_2d(get_numpy(d_std))
        df_stds = np.atleast_2d(projection_op.backward(d_std_np))

        return d_std.new(df_stds), None, None
|