From 0aeea4e2e564b8ed974875bef5dafab64a716fcb Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 3 Sep 2022 11:44:01 +0200 Subject: [PATCH] Fixed Bug: Wrong dimensions for action_loss --- metastable_baselines/ppo/ppo.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py index bfb7b3a..9e106d4 100644 --- a/metastable_baselines/ppo/ppo.py +++ b/metastable_baselines/ppo/ppo.py @@ -330,12 +330,14 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): trust_region_losses.append(trust_region_loss.item()) # 'Principle of least action' - action_loss = th.square(actions) + action_loss = th.mean(th.square(actions)) action_losses.append(action_loss) policy_loss = surrogate_loss + self.ent_coef * entropy_loss + \ trust_region_loss + self.action_coef * action_loss + import pdb + pdb.set_trace() pg_losses.append(policy_loss.item()) loss = policy_loss + self.vf_coef * value_loss