13 lines
802 B
Python
13 lines
802 B
Python
import torch
|
|
from .base_projection import BaseProjection
|
|
from typing import Dict
|
|
|
|
class IdentityProjection(BaseProjection):
    """Projection that leaves the policy parameters untouched.

    ``project`` hands back its input as-is and ``get_trust_region_loss`` is
    always a zero scalar, so this class contributes no projection and no
    penalty of its own.
    """

    def __init__(
        self,
        in_keys: list[str],
        out_keys: list[str],
        trust_region_coeff: float = 1.0,
        mean_bound: float = 0.01,
        cov_bound: float = 0.01,
    ):
        """Forward every argument unchanged to ``BaseProjection.__init__``.

        NOTE(review): none of these values are read by this subclass; how
        they are stored or interpreted is defined by ``BaseProjection``.
        """
        super().__init__(
            in_keys=in_keys,
            out_keys=out_keys,
            trust_region_coeff=trust_region_coeff,
            mean_bound=mean_bound,
            cov_bound=cov_bound,
        )

    def project(
        self,
        policy_params: Dict[str, torch.Tensor],
        old_policy_params: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Return ``policy_params`` unchanged (same dict object, no copy).

        ``old_policy_params`` is accepted for signature compatibility and is
        not used.
        """
        return policy_params

    def get_trust_region_loss(
        self,
        policy_params: Dict[str, torch.Tensor],
        proj_policy_params: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Return a zero scalar tensor as the trust-region loss.

        The tensor is placed on the device of the first value in
        ``policy_params``. ``proj_policy_params`` is not used.
        """
        # An empty dict has no tensor to take the device from; fall back to
        # torch's default device instead of leaking a bare StopIteration.
        first_param = next(iter(policy_params.values()), None)
        if first_param is None:
            return torch.tensor(0.0)
        return torch.tensor(0.0, device=first_param.device)