# itpal_jax/itpal_jax/identity_projection.py
#
# 21 lines
# 1006 B
# Python

import jax.numpy as jnp
from .base_projection import BaseProjection
from typing import Dict
import jax
from functools import partial
class IdentityProjection(BaseProjection):
    """Projection that hands the policy parameters back unmodified.

    ``project`` returns its ``policy_params`` input and
    ``get_trust_region_loss`` is a constant zero. The constructor arguments
    are only forwarded to ``BaseProjection``; neither method defined in this
    class reads them.
    """

    def __init__(
        self,
        trust_region_coeff: float = 1.0,
        mean_bound: float = 0.01,
        cov_bound: float = 0.01,
        contextual_std: bool = True,
        full_cov: bool = False,
    ):
        """Forward all settings to ``BaseProjection`` unchanged.

        Args:
            trust_region_coeff: Passed through to the base class.
            mean_bound: Passed through to the base class.
            cov_bound: Passed through to the base class.
            contextual_std: Passed through to the base class.
            full_cov: Passed through to the base class.

        NOTE(review): the meaning of these settings is defined by
        ``BaseProjection`` and is not visible from this class.
        """
        super().__init__(
            trust_region_coeff=trust_region_coeff,
            mean_bound=mean_bound,
            cov_bound=cov_bound,
            contextual_std=contextual_std,
            full_cov=full_cov,
        )

    # ``self`` is marked static so jit treats the instance as a compile-time
    # constant instead of trying to trace it. Static arguments must be
    # hashable -- NOTE(review): confirm BaseProjection does not define
    # ``__eq__`` without ``__hash__``.
    # ``('self',)`` is a real one-element tuple; the previous ``('self')`` was
    # just a parenthesised string that jax.jit happened to accept.
    @partial(jax.jit, static_argnames=('self',))
    def project(
        self,
        policy_params: Dict[str, jnp.ndarray],
        old_policy_params: Dict[str, jnp.ndarray],
    ) -> Dict[str, jnp.ndarray]:
        """Return ``policy_params`` without applying any projection.

        Args:
            policy_params: Parameters to "project"; returned as they are.
            old_policy_params: Unused here; present so the signature accepts
                the same arguments as the other method signatures expect.

        Returns:
            The entries of ``policy_params``, unchanged.
        """
        return policy_params

    # Same static-``self`` treatment as ``project`` above.
    @partial(jax.jit, static_argnames=('self',))
    def get_trust_region_loss(
        self,
        policy_params: Dict[str, jnp.ndarray],
        proj_policy_params: Dict[str, jnp.ndarray],
    ) -> jnp.ndarray:
        """Return a constant zero loss.

        Args:
            policy_params: Unused.
            proj_policy_params: Unused.

        Returns:
            A scalar array holding ``0.0``.
        """
        return jnp.array(0.0)