21 lines
1006 B
Python
21 lines
1006 B
Python
import jax.numpy as jnp
|
|
from .base_projection import BaseProjection
|
|
from typing import Dict
|
|
import jax
|
|
from functools import partial
|
|
|
|
class IdentityProjection(BaseProjection):
    """Projection whose projection step is a no-op.

    ``project`` returns the policy parameters it is given without modifying
    them, and ``get_trust_region_loss`` always returns a zero scalar. The
    constructor options are accepted only so this class has the same
    constructor signature as ``BaseProjection``; they are forwarded to the
    base class and never read by the methods defined here.
    """

    def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
                 cov_bound: float = 0.01, contextual_std: bool = True, full_cov: bool = False):
        """Forward every configuration option to ``BaseProjection.__init__``.

        Args:
            trust_region_coeff: Passed through unchanged to the base class.
            mean_bound: Passed through unchanged to the base class.
            cov_bound: Passed through unchanged to the base class.
            contextual_std: Passed through unchanged to the base class.
            full_cov: Passed through unchanged to the base class.
        """
        super().__init__(trust_region_coeff=trust_region_coeff, mean_bound=mean_bound,
                         cov_bound=cov_bound, contextual_std=contextual_std, full_cov=full_cov)

    # ``self`` is declared static so the bound method can be jitted; this makes
    # jax hash the instance (identity hashing unless BaseProjection overrides
    # __eq__/__hash__ -- not visible here, confirm). The trailing comma makes
    # this a real one-element tuple: ``('self')`` would just be the str 'self'.
    @partial(jax.jit, static_argnames=('self',))
    def project(self, policy_params: Dict[str, jnp.ndarray],
                old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        """Return ``policy_params`` without projecting them.

        Args:
            policy_params: Current policy parameters, keyed by name.
            old_policy_params: Previous policy parameters. Ignored; accepted
                only to match the projection interface.

        Returns:
            ``policy_params``, passed straight through the jitted function.
        """
        return policy_params

    # Same static-``self`` arrangement as ``project`` (note the trailing comma).
    @partial(jax.jit, static_argnames=('self',))
    def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
                              proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Return a trust-region loss that is always zero.

        Args:
            policy_params: Unprojected policy parameters. Ignored.
            proj_policy_params: Projected policy parameters. Ignored.

        Returns:
            The scalar array ``jnp.array(0.0)``, regardless of the inputs.
        """
        return jnp.array(0.0)