itpal_jax/itpal_jax/exception_projection.py
2024-12-11 18:33:40 +01:00

20 lines
689 B
Python

from .base_projection import BaseProjection
from functools import partial
import jax.numpy as jnp
from typing import Dict
class ExceptionProjection(BaseProjection):
    """Placeholder projection that raises as soon as it is instantiated.

    Construction stores ``msg`` on the instance and then unconditionally
    raises ``Exception(msg)``, so normal construction never yields a usable
    object. Presumably this stands in for a projection type that is not
    available in this backend — TODO confirm against the code that builds
    projections.

    ``project`` and ``get_trust_region_loss`` are still defined, presumably
    because ``BaseProjection`` requires them (verify). They raise instead of
    silently returning ``None``. An instance obtained without running
    ``__init__`` (e.g. via ``ExceptionProjection.__new__``) therefore also
    fails loudly rather than passing ``None`` into downstream code that
    expects a dict of arrays or an array.
    """

    def __init__(self, msg: str, *args, **kwargs):
        """Store ``msg`` and raise ``Exception(msg)``.

        Args:
            msg: Message carried by the raised exception.
            *args: Accepted and ignored, so this class can be called with
                whatever positional arguments a real projection would get.
            **kwargs: Accepted and ignored, for the same reason.

        Raises:
            Exception: Always, with ``msg`` as its message.
        """
        self.msg = msg
        raise Exception(msg)

    def _fail(self):
        """Raise ``Exception`` with the stored message, or a generic fallback."""
        # ``self.msg`` is absent when ``__init__`` was bypassed (e.g. ``__new__``).
        message = getattr(self, "msg", f"{type(self).__name__} cannot be used")
        raise Exception(message)

    def project(self, policy_params: Dict[str, jnp.ndarray],
                old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        """Never projects; always raises.

        Raises:
            Exception: Always, carrying the stored message.
        """
        self._fail()

    def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
                              proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Never computes a loss; always raises.

        Raises:
            Exception: Always, carrying the stored message.
        """
        self._fail()
def makeExceptionProjection(msg: str):
    """Build a projection factory that always raises with ``msg``.

    The result is ``functools.partial(ExceptionProjection, msg)``. Calling it
    with any further positional or keyword arguments constructs an
    ``ExceptionProjection``, whose constructor raises ``Exception(msg)``.

    Args:
        msg: Message for the exception raised when the factory is called.

    Returns:
        A ``functools.partial`` object wrapping ``ExceptionProjection``.
    """
    # ``msg`` is bound positionally so callers can still pass their own
    # positional arguments after it without colliding with ``msg``.
    factory = partial(ExceptionProjection, msg)
    return factory