From eb881559d6149553ec4770ae2a3fd80d8c2d7b0b Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 28 Aug 2022 02:07:18 +0200 Subject: [PATCH] Support clip_range None --- metastable_baselines/ppo/ppo.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py index 57b97b6..a523601 100644 --- a/metastable_baselines/ppo/ppo.py +++ b/metastable_baselines/ppo/ppo.py @@ -192,7 +192,8 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): super()._setup_model() # Initialize schedules for policy/value clipping - self.clip_range = get_schedule_fn(self.clip_range) + if self.clip_range is not None: + self.clip_range = get_schedule_fn(self.clip_range) if self.clip_range_vf is not None: if isinstance(self.clip_range_vf, (float, int)): assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping" @@ -208,7 +209,10 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): # Update optimizer learning rate self._update_learning_rate(self.policy.optimizer) # Compute current clip range - clip_range = self.clip_range(self._current_progress_remaining) + if self.clip_range: + clip_range = self.clip_range(self._current_progress_remaining) + else: + clip_range = None # Optional: clip range for the value function if self.clip_range_vf is not None: clip_range_vf = self.clip_range_vf( @@ -276,7 +280,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): # Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss' # clipped surrogate loss if self.clip_range is None: - surrogate_loss = advantages * ratio + surrogate_loss = -(advantages * ratio).mean() else: surrogate_loss_1 = advantages * ratio surrogate_loss_2 = advantages * \ @@ -286,7 +290,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): surrogate_losses.append(surrogate_loss.item()) - if self.clip_range is None: + if clip_range is None: clip_fraction = 0 else: clip_fraction = th.mean(