Error when calculating action_loss
This commit is contained in:
parent
2c14edd3b0
commit
00dbc9bdd8
@ -332,7 +332,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
|||||||
# 'Principle of least action'
|
# 'Principle of least action'
|
||||||
action_loss = th.mean(th.square(actions))
|
action_loss = th.mean(th.square(actions))
|
||||||
|
|
||||||
action_losses.append(action_loss)
|
action_losses.append(action_loss.item())
|
||||||
|
|
||||||
policy_loss = surrogate_loss + self.ent_coef * entropy_loss + \
|
policy_loss = surrogate_loss + self.ent_coef * entropy_loss + \
|
||||||
trust_region_loss + self.action_coef * action_loss
|
trust_region_loss + self.action_coef * action_loss
|
||||||
|
Loading…
Reference in New Issue
Block a user