From ca1ee980ef49eaf4808df651536b5fdcf3593242 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Mon, 21 Oct 2024 15:23:57 +0200
Subject: [PATCH] Fix: Bug in loss calc for TRPLLoss

---
 fancy_rl/objectives/trpl.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/fancy_rl/objectives/trpl.py b/fancy_rl/objectives/trpl.py
index 4094307..ad72665 100644
--- a/fancy_rl/objectives/trpl.py
+++ b/fancy_rl/objectives/trpl.py
@@ -83,8 +83,8 @@ class TRPLLoss(PPOLoss):
 
     def _trust_region_loss(self, tensordict):
         old_distribution = self.old_actor_network(tensordict)
-        raw_distribution = self.actor_network(tensordict)
-        return self.projection(self.actor_network, raw_distribution, old_distribution)
+        new_distribution = self.actor_network(tensordict)
+        return self.projection.get_trust_region_loss(new_distribution, old_distribution)
 
     def forward(self, tensordict: TensorDictBase) -> TensorDict:
         tensordict = tensordict.clone(False)