Error when calculating action_loss
This commit is contained in:
parent
2c14edd3b0
commit
00dbc9bdd8
@@ -332,7 +332,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
# 'Principle of least action'
|
||||
action_loss = th.mean(th.square(actions))
|
||||
|
||||
- action_losses.append(action_loss)
|
||||
+ action_losses.append(action_loss.item())
|
||||
|
||||
policy_loss = surrogate_loss + self.ent_coef * entropy_loss + \
|
||||
trust_region_loss + self.action_coef * action_loss
|
||||
|
Loading…
Reference in New Issue
Block a user