Fix: Bug in loss calc for TRPLLoss

This commit is contained in:
Dominik Moritz Roth 2024-10-21 15:23:57 +02:00
parent 0c6e58634f
commit ca1ee980ef

View File

@ -83,8 +83,8 @@ class TRPLLoss(PPOLoss):
     def _trust_region_loss(self, tensordict):
         old_distribution = self.old_actor_network(tensordict)
-        raw_distribution = self.actor_network(tensordict)
-        return self.projection(self.actor_network, raw_distribution, old_distribution)
+        new_distribution = self.actor_network(tensordict)
+        return self.projection.get_trust_region_loss(new_distribution, old_distribution)
     def forward(self, tensordict: TensorDictBase) -> TensorDict:
         tensordict = tensordict.clone(False)