import cpp_projection
import numpy as np
import torch as ch
from typing import Any, Tuple

from trust_region_projections.models.policy.abstract_gaussian_policy import AbstractGaussianPolicy
from trust_region_projections.projections.base_projection_layer import BaseProjectionLayer, mean_projection
from trust_region_projections.utils.projection_utils import gaussian_kl
from trust_region_projections.utils.torch_utils import get_numpy


class KLProjectionLayer(BaseProjectionLayer):

    def _trust_region_projection(self, policy: AbstractGaussianPolicy, p: Tuple[ch.Tensor, ch.Tensor],
                                 q: Tuple[ch.Tensor, ch.Tensor], eps: ch.Tensor, eps_cov: ch.Tensor, **kwargs):
        """
        Runs KL projection layer and constructs cholesky of covariance
        Args:
            policy: policy instance
            p: current distribution
            q: old distribution
            eps: (modified) kl bound/ kl bound for mean part
            eps_cov: (modified) kl bound for cov part
            **kwargs:

        Returns:
            projected mean, projected cov cholesky
        """
        mean, std = p
        old_mean, old_std = q

        if not policy.contextual_std:
            # only project first one to reduce number of numerical optimizations
            std = std[:1]
            old_std = old_std[:1]

        ################################################################################################################
        # project mean with closed form
        mean_part, _ = gaussian_kl(policy, p, q)
        proj_mean = mean_projection(mean, old_mean, mean_part, eps)

        cov = policy.covariance(std)
        old_cov = policy.covariance(old_std)

        if policy.is_diag:
            proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov.diagonal(dim1=-2, dim2=-1),
                                                                 old_cov.diagonal(dim1=-2, dim2=-1), eps_cov)
            proj_std = proj_cov.sqrt().diag_embed()
        else:
            raise NotImplementedError("The KL projection currently does not support full covariance matrices.")

        if not policy.contextual_std:
            # scale first std back to batchsize
            proj_std = proj_std.expand(mean.shape[0], -1, -1)

        return proj_mean, proj_std


class KLProjectionGradFunctionDiagCovOnly(ch.autograd.Function):
    projection_op = None

    @staticmethod
    def get_projection_op(batch_shape, dim, max_eval=100):
        if not KLProjectionGradFunctionDiagCovOnly.projection_op:
            KLProjectionGradFunctionDiagCovOnly.projection_op = \
                cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=max_eval)
        return KLProjectionGradFunctionDiagCovOnly.projection_op

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        std, old_std, eps_cov = args

        batch_shape = std.shape[0]
        dim = std.shape[-1]

        cov_np = get_numpy(std)
        old_std = get_numpy(old_std)
        eps = get_numpy(eps_cov) * np.ones(batch_shape)

        # p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim)
        # ctx.proj = projection_op

        p_op = KLProjectionGradFunctionDiagCovOnly.get_projection_op(batch_shape, dim)
        ctx.proj = p_op

        proj_std = p_op.forward(eps, old_std, cov_np)

        return std.new(proj_std)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        projection_op = ctx.proj
        d_std, = grad_outputs

        d_std_np = get_numpy(d_std)
        d_std_np = np.atleast_2d(d_std_np)
        df_stds = projection_op.backward(d_std_np)
        df_stds = np.atleast_2d(df_stds)

        return d_std.new(df_stds), None, None
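

if __name__ == "__main__":
    # Minimal usage sketch (not part of the library API): project the diagonal
    # covariance entries of a batch of Gaussians back onto the KL trust region
    # around the old covariances via the autograd function above. Assumes the
    # compiled cpp_projection extension is installed; the batch size, dimension,
    # and eps_cov value below are illustrative assumptions only.
    batch_size, dim = 8, 4
    old_cov_diag = ch.rand(batch_size, dim) + 0.5                        # old diagonal covariances
    cov_diag = old_cov_diag * (1.0 + 0.5 * ch.rand(batch_size, dim))     # perturbed new covariances
    eps_cov = ch.tensor(0.01)                                            # KL bound for the covariance part

    proj_cov_diag = KLProjectionGradFunctionDiagCovOnly.apply(cov_diag, old_cov_diag, eps_cov)
    print(proj_cov_diag.shape)  # projected diagonal covariances, shape (batch_size, dim)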