diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py index bfb7b3a..9e106d4 100644 --- a/metastable_baselines/ppo/ppo.py +++ b/metastable_baselines/ppo/ppo.py @@ -330,12 +330,14 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): trust_region_losses.append(trust_region_loss.item()) # 'Principle of least action' - action_loss = th.square(actions) + action_loss = th.mean(th.square(actions)) action_losses.append(action_loss) policy_loss = surrogate_loss + self.ent_coef * entropy_loss + \ trust_region_loss + self.action_coef * action_loss + import pdb + pdb.set_trace() pg_losses.append(policy_loss.item()) loss = policy_loss + self.vf_coef * value_loss