diff --git a/fancy_rl/objectives/trpl.py b/fancy_rl/objectives/trpl.py index 10abc53..4094307 100644 --- a/fancy_rl/objectives/trpl.py +++ b/fancy_rl/objectives/trpl.py @@ -38,100 +38,40 @@ from torchrl.objectives.value import ( ) from torchrl.objectives.ppo import PPOLoss +from fancy_rl.projections import get_projection class TRPLLoss(PPOLoss): - @dataclass - class _AcceptedKeys: - """Maintains default values for all configurable tensordict keys. - - This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their - default values - - Attributes: - advantage (NestedKey): The input tensordict key where the advantage is expected. - Will be used for the underlying value estimator. Defaults to ``"advantage"``. - value_target (NestedKey): The input tensordict key where the target state value is expected. - Will be used for the underlying value estimator Defaults to ``"value_target"``. - value (NestedKey): The input tensordict key where the state value is expected. - Will be used for the underlying value estimator. Defaults to ``"state_value"``. - sample_log_prob (NestedKey): The input tensordict key where the - sample log probability is expected. Defaults to ``"sample_log_prob"``. - action (NestedKey): The input tensordict key where the action is expected. - Defaults to ``"action"``. - reward (NestedKey): The input tensordict key where the reward is expected. - Will be used for the underlying value estimator. Defaults to ``"reward"``. - done (NestedKey): The key in the input TensorDict that indicates - whether a trajectory is done. Will be used for the underlying value estimator. - Defaults to ``"done"``. - terminated (NestedKey): The key in the input TensorDict that indicates - whether a trajectory is terminated. Will be used for the underlying value estimator. - Defaults to ``"terminated"``. - """ - - advantage: NestedKey = "advantage" - value_target: NestedKey = "value_target" - value: NestedKey = "state_value" - sample_log_prob: NestedKey = "sample_log_prob" - action: NestedKey = "action" - reward: NestedKey = "reward" - done: NestedKey = "done" - terminated: NestedKey = "terminated" - - default_keys = _AcceptedKeys() - default_value_estimator = ValueEstimators.GAE - - - actor_network: TensorDictModule - critic_network: TensorDictModule - actor_network_params: TensorDictParams - critic_network_params: TensorDictParams - target_actor_network_params: TensorDictParams - target_critic_network_params: TensorDictParams - def __init__( self, - actor_network: ProbabilisticTensorDictSequential | None = None, - critic_network: TensorDictModule | None = None, - trust_region_layer: any | None = None, - entropy_bonus: bool = True, - samples_mc_entropy: int = 1, + actor_network: ProbabilisticTensorDictSequential, + old_actor_network: ProbabilisticTensorDictSequential, + critic_network: TensorDictModule, + projection: any, entropy_coef: float = 0.01, critic_coef: float = 1.0, trust_region_coef: float = 10.0, - loss_critic_type: str = "smooth_l1", normalize_advantage: bool = False, - gamma: float = None, - separate_losses: bool = False, - reduction: str = None, - clip_value: bool | float | None = None, **kwargs, ): - self.trust_region_layer = trust_region_layer - self.trust_region_coef = trust_region_coef - - super(TRPLLoss, self).__init__( - actor_network, - critic_network, - entropy_bonus=entropy_bonus, - samples_mc_entropy=samples_mc_entropy, + super().__init__( + actor_network=actor_network, + critic_network=critic_network, entropy_coef=entropy_coef, critic_coef=critic_coef, - loss_critic_type=loss_critic_type, normalize_advantage=normalize_advantage, - gamma=gamma, - separate_losses=separate_losses, - reduction=reduction, - clip_value=clip_value, **kwargs, ) + self.old_actor_network = old_actor_network + self.projection = projection + self.trust_region_coef = trust_region_coef @property def out_keys(self): if self._out_keys is None: - keys = ["loss_objective"] + keys = ["loss_objective", "tr_loss"] if self.entropy_bonus: keys.extend(["entropy", "loss_entropy"]) - if self.loss_critic: + if self.critic_coef: keys.append("loss_critic") keys.append("ESS") self._out_keys = keys @@ -141,8 +81,12 @@ class TRPLLoss(PPOLoss): def out_keys(self, values): self._out_keys = values - @dispatch - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + def _trust_region_loss(self, tensordict): + old_distribution = self.old_actor_network(tensordict) + raw_distribution = self.actor_network(tensordict) + return self.projection(self.actor_network, raw_distribution, old_distribution) + + def forward(self, tensordict: TensorDictBase) -> TensorDict: tensordict = tensordict.clone(False) advantage = tensordict.get(self.tensor_keys.advantage, None) if advantage is None: @@ -159,11 +103,8 @@ class TRPLLoss(PPOLoss): log_weight, dist, kl_approx = self._log_weight(tensordict) trust_region_loss_unscaled = self._trust_region_loss(tensordict) - # ESS for logging + with torch.no_grad(): - # In theory, ESS should be computed on particles sampled from the same source. Here we sample according - # to different, unrelated trajectories, which is not standard. Still it can give a idea of the dispersion - # of the weights. lw = log_weight.squeeze() ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp() batch = log_weight.shape[0] @@ -194,8 +135,3 @@ class TRPLLoss(PPOLoss): batch_size=[], ) return td_out - - def _trust_region_loss(self, tensordict): - old_distribution = - raw_distribution = - return self.policy_projection.get_trust_region_loss(raw_distribution, old_distribution) \ No newline at end of file