76 lines
3.2 KiB
Python
76 lines
3.2 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import Dict
|
|
import jax
|
|
import jax.numpy as jnp
|
|
from functools import partial
|
|
|
|
class BaseProjection(ABC):
    """Abstract base class for trust-region projections of policy parameters.

    Concrete subclasses must implement :meth:`project` and
    :meth:`get_trust_region_loss`. This base class stores the configuration
    and offers helpers for projecting the mean and for converting between
    scale and covariance representations.
    """

    def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
                 cov_bound: float = 0.01, contextual_std: bool = True, full_cov: bool = False):
        """Store the projection configuration.

        Args:
            trust_region_coeff: Stored as-is; not used by the helpers in this
                class (presumably a weight for the trust region loss in
                subclasses -- confirm against the implementations).
            mean_bound: Threshold on the mean distance; entries above it are
                projected by :meth:`_mean_projection`.
            cov_bound: Stored as-is; presumably the analogous threshold for
                the covariance part -- confirm against subclasses.
            contextual_std: Stored as-is; not used by the helpers in this class.
            full_cov: If True, the helpers read ``params["scale_tril"]`` and
                work with Cholesky factors; otherwise they read
                ``params["scale"]``.
        """
        self.trust_region_coeff = trust_region_coeff
        self.mean_bound = mean_bound
        self.cov_bound = cov_bound
        self.contextual_std = contextual_std
        self.full_cov = full_cov

    @abstractmethod
    def project(self, policy_params: Dict[str, jnp.ndarray],
                old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        """Project parameters to satisfy trust region constraints."""
        raise NotImplementedError

    @abstractmethod
    def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
                              proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Compute trust region loss between original and projected parameters."""
        raise NotImplementedError

    def _mean_projection(self, mean: jnp.ndarray, old_mean: jnp.ndarray,
                         mean_part: jnp.ndarray) -> jnp.ndarray:
        """Project mean based on the Mahalanobis objective and trust region.

        Entries whose ``mean_part`` exceeds ``self.mean_bound`` are replaced by
        an interpolation between ``mean`` and ``old_mean``; all other entries
        are returned unchanged. The computation contains no Python-level
        branching on array values, so it can be traced by ``jax.jit`` /
        ``jax.vmap`` / ``jax.grad``.

        Args:
            mean: Current mean vectors
            old_mean: Old mean vectors
            mean_part: Mahalanobis/Euclidean distance between the two mean
                vectors, with one entry per mean vector (it is broadcast
                against ``mean`` by appending a trailing axis).

        Returns:
            Projected mean that satisfies the trust region
        """
        mask = mean_part > self.mean_bound

        # No early exit on `jnp.any(mask)`: converting a traced array to a
        # Python bool fails under jit, and the final `jnp.where` already
        # returns `mean` wherever nothing has to be projected.

        # Feed a harmless constant into the sqrt for entries that are not
        # projected. Otherwise d/dx sqrt(x) at x == 0 is inf and the zero
        # cotangent coming out of `jnp.where` turns it into NaN gradients.
        safe_part = jnp.where(mask, mean_part, self.mean_bound)

        # Compute projection factor (1 is a placeholder for unprojected
        # entries; their result is discarded by the final `jnp.where`).
        omega = jnp.where(
            mask,
            jnp.sqrt(safe_part / self.mean_bound) - 1.,
            jnp.ones(mean_part.shape, dtype=mean.dtype),
        )
        # Absolute value, with a trailing axis for broadcasting against mean.
        omega = jnp.maximum(-omega, omega)[..., None]

        # Project mean; the small constant guards the division.
        m = (mean + omega * old_mean) / (1. + omega + 1e-16)
        return jnp.where(mask[..., None], m, mean)

    def _cov_projection(self, scale_or_tril: jnp.ndarray, old_scale_or_tril: jnp.ndarray,
                        cov_part: jnp.ndarray) -> jnp.ndarray:
        """Project covariance parameters.

        Not implemented in the base class; always raises
        ``NotImplementedError``. It is not marked abstract, so subclasses
        only need to override it if they call it.
        """
        raise NotImplementedError

    def _calc_covariance(self, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Convert scale representation to covariance matrix.

        Args:
            params: Parameter dict. Must contain ``"scale"`` when
                ``self.full_cov`` is False and ``"scale_tril"`` otherwise.

        Returns:
            The element-wise square of ``params["scale"]`` in the diagonal
            case (same shape as the scale, not an embedded matrix), or
            ``scale_tril @ scale_tril^T`` over the last two axes in the
            full-covariance case.
        """
        if self.full_cov:
            scale_tril = params["scale_tril"]  # Cholesky factor
            return jnp.matmul(scale_tril, jnp.swapaxes(scale_tril, -1, -2))

        scale = params["scale"]  # standard deviations
        return jnp.square(scale)  # diagonal covariance

    def _calc_scale_from_cov(self, cov: jnp.ndarray) -> jnp.ndarray:
        """Convert covariance matrix back to appropriate scale representation.

        Args:
            cov: Covariance with at least two trailing matrix axes.

        Returns:
            The square root of the diagonal over the last two axes in the
            diagonal case, or the Cholesky factor of ``cov`` in the
            full-covariance case.
        """
        if self.full_cov:
            return jnp.linalg.cholesky(cov)  # Cholesky factor

        # NOTE(review): this expects a matrix-shaped `cov`, whereas
        # `_calc_covariance` returns the element-wise squared scale in the
        # diagonal case, so the two helpers are not inverses of each other
        # there -- confirm which shape callers actually pass.
        return jnp.sqrt(jnp.diagonal(cov, axis1=-2, axis2=-1))  # standard deviations