From 1d3c2fe005b68f1fa1ebad8de9c698e6aa909720 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 28 Aug 2022 00:26:44 +0200 Subject: [PATCH] Allow completely disabling some PPO features (for TRPL) --- metastable_baselines/ppo/ppo.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py index ceed889..57b97b6 100644 --- a/metastable_baselines/ppo/ppo.py +++ b/metastable_baselines/ppo/ppo.py @@ -94,12 +94,12 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): n_epochs: int = 10, gamma: float = 0.99, gae_lambda: float = 0.95, - clip_range: Union[float, Schedule] = 0.2, + clip_range: Union[None, float, Schedule] = 0.2, clip_range_vf: Union[None, float, Schedule] = None, normalize_advantage: bool = True, ent_coef: float = 0.0, vf_coef: float = 0.5, - max_grad_norm: float = 0.5, + max_grad_norm: Union[None, float] = 0.5, use_sde: bool = False, sde_sample_freq: int = -1, target_kl: Optional[float] = None, @@ -275,16 +275,22 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): # Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss' # clipped surrogate loss - surrogate_loss_1 = advantages * ratio - surrogate_loss_2 = advantages * \ - th.clamp(ratio, 1 - clip_range, 1 + clip_range) - surrogate_loss = - \ - th.min(surrogate_loss_1, surrogate_loss_2).mean() + if self.clip_range is None: + surrogate_loss = advantages * ratio + else: + surrogate_loss_1 = advantages * ratio + surrogate_loss_2 = advantages * \ + th.clamp(ratio, 1 - clip_range, 1 + clip_range) + surrogate_loss = - \ + th.min(surrogate_loss_1, surrogate_loss_2).mean() surrogate_losses.append(surrogate_loss.item()) - clip_fraction = th.mean( - (th.abs(ratio - 1) > clip_range).float()).item() + if self.clip_range is None: + clip_fraction = 0 + else: + clip_fraction = th.mean( + (th.abs(ratio - 1) > clip_range).float()).item() clip_fractions.append(clip_fraction) if self.clip_range_vf is None: @@ -341,8 +347,9 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): self.policy.optimizer.zero_grad() loss.backward() # Clip grad norm - th.nn.utils.clip_grad_norm_( - self.policy.parameters(), self.max_grad_norm) + if self.max_grad_norm is not None: + th.nn.utils.clip_grad_norm_( + self.policy.parameters(), self.max_grad_norm) self.policy.optimizer.step() if not continue_training: @@ -380,7 +387,8 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") - self.logger.record("train/clip_range", clip_range) + if self.clip_range is not None: + self.logger.record("train/clip_range", clip_range) if self.clip_range_vf is not None: self.logger.record("train/clip_range_vf", clip_range_vf)