From e4a8cfc34989e35d801143b5148db831f9e2f27d Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Sat, 3 Sep 2022 11:16:29 +0200
Subject: [PATCH] Implemented action_loss

---
 metastable_baselines/ppo/ppo.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py
index a523601..07aca36 100644
--- a/metastable_baselines/ppo/ppo.py
+++ b/metastable_baselines/ppo/ppo.py
@@ -99,6 +99,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         normalize_advantage: bool = True,
         ent_coef: float = 0.0,
         vf_coef: float = 0.5,
+        action_coef: float = 0.0,
         max_grad_norm: Union[None, float] = 0.5,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
@@ -181,6 +182,8 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         self.normalize_advantage = normalize_advantage
         self.target_kl = target_kl
 
+        self.action_coef = action_coef
+
         # Different from PPO:
         self.projection = projection
         self._global_steps = 0
@@ -221,6 +224,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         surrogate_losses = []
         entropy_losses = []
         trust_region_losses = []
+        action_losses = []
         pg_losses, value_losses = [], []
         clip_fractions = []
 
@@ -325,7 +329,13 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
 
             trust_region_losses.append(trust_region_loss.item())
 
-            policy_loss = surrogate_loss + self.ent_coef * entropy_loss + trust_region_loss
+            # 'Principle of least action': penalize squared action magnitude
+            action_loss = th.mean(th.square(actions))
+
+            action_losses.append(action_loss.item())
+
+            policy_loss = surrogate_loss + self.ent_coef * entropy_loss + \
+                trust_region_loss + self.action_coef * action_loss
 
             pg_losses.append(policy_loss.item())
             loss = policy_loss + self.vf_coef * value_loss
@@ -369,6 +379,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
             self.logger.record("train/trust_region_loss",
                                np.mean(trust_region_losses))
             self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
+            self.logger.record("train/action_loss", np.mean(action_losses))
             self.logger.record("train/value_loss", np.mean(value_losses))
             self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
             self.logger.record("train/clip_fraction", np.mean(clip_fractions))
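
A minimal sketch of what the patch adds, shown outside the diff context: the
squared action magnitude is reduced to a scalar mean and scaled by action_coef
before being folded into the policy loss. The tensor shapes and placeholder
loss values below are illustrative assumptions, not taken from the repository.

    import torch as th

    # Placeholder scalar loss terms, standing in for the values computed
    # during one PPO minibatch update.
    surrogate_loss = th.tensor(0.12)
    entropy_loss = th.tensor(-1.30)
    trust_region_loss = th.tensor(0.01)
    ent_coef, action_coef = 0.0, 1e-3

    # Hypothetical minibatch of actions: 64 samples, 6-dim action space.
    actions = th.randn(64, 6)

    # 'Principle of least action': mean squared action magnitude, reduced
    # to a scalar so it combines with the other scalar loss terms.
    action_loss = th.mean(th.square(actions))

    policy_loss = surrogate_loss + ent_coef * entropy_loss \
        + trust_region_loss + action_coef * action_loss

With the default action_coef = 0.0 the penalty term is inactive and the update
matches the previous behaviour; a small positive value biases the policy
toward low-magnitude actions.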