From 8bbb01504bedb5ad60d29e218c4a5d7e4a152d20 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Tue, 23 Jan 2024 13:32:38 +0100
Subject: [PATCH] Fix bug when self.max_grad_norm=None

---
 metastable_baselines2/trpl/trpl.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/metastable_baselines2/trpl/trpl.py b/metastable_baselines2/trpl/trpl.py
index 5b85692..18e5104 100644
--- a/metastable_baselines2/trpl/trpl.py
+++ b/metastable_baselines2/trpl/trpl.py
@@ -92,7 +92,7 @@ class TRPL(BetterOnPolicyAlgorithm):
         normalize_advantage: bool = True,
         ent_coef: float = 0.0,
         vf_coef: float = 0.5,
-        max_grad_norm: float = 0.5,
+        max_grad_norm: Union[float, None] = None,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         use_pca: bool = False,
@@ -306,7 +306,8 @@ class TRPL(BetterOnPolicyAlgorithm):
                 self.policy.optimizer.zero_grad()
                 loss.backward()
                 # Clip grad norm
-                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+                if self.max_grad_norm is not None:
+                    th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                 self.policy.optimizer.step()
 
             self._n_updates += 1