Fix: Bug in loss calc for TRPLLoss
parent 0c6e58634f
commit ca1ee980ef
@@ -83,8 +83,8 @@ class TRPLLoss(PPOLoss):
     def _trust_region_loss(self, tensordict):
         old_distribution = self.old_actor_network(tensordict)
-        raw_distribution = self.actor_network(tensordict)
-        return self.projection(self.actor_network, raw_distribution, old_distribution)
+        new_distribution = self.actor_network(tensordict)
+        return self.projection.get_trust_region_loss(new_distribution, old_distribution)

     def forward(self, tensordict: TensorDictBase) -> TensorDict:
         tensordict = tensordict.clone(False)
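For context, the fix stops passing the actor network into `self.projection(...)` and instead asks the projection for its dedicated trust-region loss between the new and old policy distributions. Below is a minimal sketch of how that call could behave, assuming the projection penalizes KL divergence between the policies; the `KLProjection` class and the Gaussian parameters are illustrative stand-ins, not the repository's actual implementation.

# Minimal sketch (not the repository's actual code) of the fixed call.
# `get_trust_region_loss` is assumed here to return a differentiable
# divergence between the new and old policy distributions.
import torch
from torch.distributions import Normal, kl_divergence


class KLProjection:
    """Hypothetical stand-in for the projection used by TRPLLoss."""

    def get_trust_region_loss(self, new_distribution, old_distribution):
        # Penalize divergence of the updated policy from the old one,
        # averaged over the batch.
        return kl_divergence(old_distribution, new_distribution).mean()


old_dist = Normal(torch.zeros(4), torch.ones(4))
new_dist = Normal(torch.full((4,), 0.1), torch.ones(4))
loss = KLProjection().get_trust_region_loss(new_dist, old_dist)
print(loss)  # small positive scalar

Note that the signature matches the diff exactly: the loss depends only on the two distributions, so the actor network no longer needs to be handed to the projection.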