diff --git a/metastable_baselines2/trpl/trpl.py b/metastable_baselines2/trpl/trpl.py index c1329e8..5b85692 100644 --- a/metastable_baselines2/trpl/trpl.py +++ b/metastable_baselines2/trpl/trpl.py @@ -253,8 +253,12 @@ class TRPL(BetterOnPolicyAlgorithm): # Logging pg_losses.append(surrogate_loss.item()) - clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() + if self.clip_range is not None: + clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() + else: + clip_fraction = 0 clip_fractions.append(clip_fraction) + if self.clip_range_vf is None: # No clipping