diff --git a/metastable_baselines2/trpl/trpl.py b/metastable_baselines2/trpl/trpl.py index 5b85692..18e5104 100644 --- a/metastable_baselines2/trpl/trpl.py +++ b/metastable_baselines2/trpl/trpl.py @@ -92,7 +92,7 @@ class TRPL(BetterOnPolicyAlgorithm): normalize_advantage: bool = True, ent_coef: float = 0.0, vf_coef: float = 0.5, - max_grad_norm: float = 0.5, + max_grad_norm: Union[float, None] = None, use_sde: bool = False, sde_sample_freq: int = -1, use_pca: bool = False, @@ -306,7 +306,8 @@ class TRPL(BetterOnPolicyAlgorithm): self.policy.optimizer.zero_grad() loss.backward() # Clip grad norm - th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + if self.max_grad_norm is not None: + th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) self.policy.optimizer.step() self._n_updates += 1