# itpal_jax/itpal_jax/base_projection.py
# (file-viewer header: 76 lines, 3.2 KiB, Python)

from abc import ABC, abstractmethod
from typing import Dict
import jax
import jax.numpy as jnp
from functools import partial
class BaseProjection(ABC):
    """Abstract base class for trust-region projections of Gaussian parameters.

    Subclasses must implement :meth:`project` and :meth:`get_trust_region_loss`.
    This base class stores the configuration and provides helpers for the mean
    projection and for converting between scale and covariance representations.
    """

    def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
                 cov_bound: float = 0.01, contextual_std: bool = True, full_cov: bool = False):
        """Store the projection configuration.

        Args:
            trust_region_coeff: Stored on the instance; not used by this base class.
            mean_bound: Threshold on ``mean_part`` above which
                :meth:`_mean_projection` pulls a mean back towards the old mean.
            cov_bound: Stored on the instance; not used by this base class.
            contextual_std: Stored on the instance; not used by this base class.
            full_cov: If True the helpers read ``params["scale_tril"]`` and work
                with Cholesky factors; otherwise they read ``params["scale"]``.
        """
        self.trust_region_coeff = trust_region_coeff
        self.mean_bound = mean_bound
        self.cov_bound = cov_bound
        self.contextual_std = contextual_std
        self.full_cov = full_cov

    @abstractmethod
    def project(self, policy_params: Dict[str, jnp.ndarray],
                old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        """Project parameters to satisfy trust region constraints."""
        raise NotImplementedError

    @abstractmethod
    def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
                              proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Compute trust region loss between original and projected parameters."""
        raise NotImplementedError

    def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
                         mean_part: jnp.ndarray) -> jnp.ndarray:
        """Project mean based on the Mahalanobis objective and trust region.

        Entries whose ``mean_part`` exceeds ``self.mean_bound`` are replaced by
        the interpolation ``(mean + omega * old_mean) / (1 + omega)`` with
        ``omega = sqrt(mean_part / mean_bound) - 1``; all other entries are
        returned unchanged.

        Args:
            mean: Current mean vectors.
            old_mean: Old mean vectors.
            mean_part: Mahalanobis/Euclidean distance between the two mean
                vectors; indexed without the trailing axis of ``mean``
                (it is expanded with ``[..., None]`` before broadcasting).

        Returns:
            Projected mean that satisfies the trust region.
        """
        mask = mean_part > self.mean_bound

        # No Python-level early exit on ``jnp.any(mask)``: branching on an array
        # value fails when this method is traced (jit/vmap). The final
        # ``jnp.where`` already returns ``mean`` wherever nothing is projected.

        # ``jnp.where`` evaluates both branches, and the derivative of sqrt at 0
        # is infinite, which turns into NaN gradients for unselected entries
        # (e.g. when mean == old_mean). Feed sqrt a harmless value there; the
        # forward result of the selected entries is unaffected.
        safe_part = jnp.where(mask, mean_part, self.mean_bound)

        # Compute projection factor (1 is a placeholder for unselected entries).
        omega = jnp.where(mask,
                          jnp.sqrt(safe_part / self.mean_bound) - 1.,
                          jnp.ones(mean_part.shape, dtype=mean.dtype))
        # Element-wise absolute value, then add a trailing axis to broadcast
        # against the mean vectors.
        omega = jnp.maximum(-omega, omega)[..., None]

        # Project mean; the small epsilon guards the division.
        m = (mean + omega * old_mean) / (1. + omega + 1e-16)
        return jnp.where(mask[..., None], m, mean)

    def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray,
                        cov_part: jnp.ndarray) -> jnp.ndarray:
        """Project covariance parameters.

        Not marked abstract, so subclasses may omit it; the base implementation
        always raises.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Convert scale representation to covariance matrix.

        Args:
            params: Must contain ``"scale"`` when ``full_cov`` is False and
                ``"scale_tril"`` when it is True.

        Returns:
            The element-wise square of ``scale`` (diagonal case), or
            ``scale_tril @ scale_tril^T`` over the last two axes (full case).
        """
        if self.full_cov:
            scale_tril = params["scale_tril"]  # Cholesky factor
            return jnp.matmul(scale_tril, jnp.swapaxes(scale_tril, -1, -2))
        scale = params["scale"]  # standard deviations
        return jnp.square(scale)  # diagonal covariance

    def _calc_scale_from_cov(self, cov: jnp.ndarray) -> jnp.ndarray:
        """Convert covariance matrix back to appropriate scale representation.

        Args:
            cov: Covariance with at least two trailing matrix axes.

        Returns:
            The Cholesky factor of ``cov`` (full case), or the square root of
            its diagonal taken over the last two axes (diagonal case).
        """
        if self.full_cov:
            return jnp.linalg.cholesky(cov)  # Cholesky factor
        # NOTE(review): this extracts a diagonal from a matrix, while
        # _calc_covariance squares "scale" element-wise without changing its
        # shape -- the round trip only lines up if "scale" is stored as a
        # diagonal matrix. Confirm against callers.
        return jnp.sqrt(jnp.diagonal(cov, axis1=-2, axis2=-1))  # standard deviations