diff --git a/fancy_rl/objectives/trpl.py b/fancy_rl/objectives/trpl.py index 4094307..ad72665 100644 --- a/fancy_rl/objectives/trpl.py +++ b/fancy_rl/objectives/trpl.py @@ -83,8 +83,8 @@ class TRPLLoss(PPOLoss): def _trust_region_loss(self, tensordict): old_distribution = self.old_actor_network(tensordict) - raw_distribution = self.actor_network(tensordict) - return self.projection(self.actor_network, raw_distribution, old_distribution) + new_distribution = self.actor_network(tensordict) + return self.projection.get_trust_region_loss(new_distribution, old_distribution) def forward(self, tensordict: TensorDictBase) -> TensorDict: tensordict = tensordict.clone(False)