Fix: Bug in loss calc for TRPLLoss
This commit is contained in:
		
							parent
							
								
									0c6e58634f
								
							
						
					
					
						commit
						ca1ee980ef
					
				@ -83,8 +83,8 @@ class TRPLLoss(PPOLoss):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def _trust_region_loss(self, tensordict):
 | 
					    def _trust_region_loss(self, tensordict):
 | 
				
			||||||
        old_distribution = self.old_actor_network(tensordict)
 | 
					        old_distribution = self.old_actor_network(tensordict)
 | 
				
			||||||
        raw_distribution = self.actor_network(tensordict)
 | 
					        new_distribution = self.actor_network(tensordict)
 | 
				
			||||||
        return self.projection(self.actor_network, raw_distribution, old_distribution)
 | 
					        return self.projection.get_trust_region_loss(new_distribution, old_distribution)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(self, tensordict: TensorDictBase) -> TensorDict:
 | 
					    def forward(self, tensordict: TensorDictBase) -> TensorDict:
 | 
				
			||||||
        tensordict = tensordict.clone(False)
 | 
					        tensordict = tensordict.clone(False)
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user