Fix error when calculating action_loss

This commit is contained in:
Dominik Moritz Roth 2022-09-12 22:28:57 +02:00
parent 2c14edd3b0
commit 00dbc9bdd8

View File

@@ -332,7 +332,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
# 'Principle of least action'
action_loss = th.mean(th.square(actions))
-action_losses.append(action_loss)
+action_losses.append(action_loss.item())
policy_loss = surrogate_loss + self.ent_coef * entropy_loss + \
trust_region_loss + self.action_coef * action_loss