20 lines
689 B
Python
20 lines
689 B
Python
from .base_projection import BaseProjection
|
|
from functools import partial
|
|
import jax.numpy as jnp
|
|
from typing import Dict
|
|
|
|
class ExceptionProjection(BaseProjection):
    """Placeholder projection that fails as soon as it is instantiated.

    Constructing this class always raises ``Exception(msg)``. Any extra
    positional and keyword arguments are accepted and ignored, so a
    constructor call with arbitrary projection arguments fails with ``msg``
    rather than with a ``TypeError`` about the signature.
    """

    def __init__(self, msg, *args, **kwargs):
        """Store ``msg`` and immediately raise it.

        Args:
            msg: Message carried by the raised exception.
            *args: Ignored.
            **kwargs: Ignored.

        Raises:
            Exception: Always, with ``msg`` as its message.
        """
        self.msg = msg
        raise Exception(msg)

    def project(self, policy_params: Dict[str, jnp.ndarray],
                old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        """Not usable on this placeholder.

        Raises:
            NotImplementedError: Always. Raising here replaces the previous
                silent ``None`` return, which contradicted the declared
                return type.
        """
        raise NotImplementedError(self._error_message())

    def get_trust_region_loss(self, policy_params: Dict[str, jnp.ndarray],
                              proj_policy_params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Not usable on this placeholder.

        Raises:
            NotImplementedError: Always. Raising here replaces the previous
                silent ``None`` return, which contradicted the declared
                return type.
        """
        raise NotImplementedError(self._error_message())

    def _error_message(self) -> str:
        """Return the stored message, or a generic one if ``msg`` was never set."""
        # ``msg`` is absent when ``__init__`` was bypassed (e.g. via ``__new__``
        # or a subclass that overrides ``__init__``).
        return getattr(self, "msg", "ExceptionProjection cannot be used")
|
|
|
|
def makeExceptionProjection(msg: str):
    """Build a deferred ``ExceptionProjection`` constructor bound to ``msg``.

    Args:
        msg: Message for the exception raised when the result is called.

    Returns:
        A ``functools.partial`` wrapping ``ExceptionProjection`` with ``msg``
        bound as its first positional argument. Calling it, with any further
        arguments, constructs ``ExceptionProjection`` and therefore raises.
    """
    deferred_ctor = partial(ExceptionProjection, msg)
    return deferred_ctor