diff --git a/metastable_baselines2/trpl/trpl.py b/metastable_baselines2/trpl/trpl.py index ea16a50..ed819b1 100644 --- a/metastable_baselines2/trpl/trpl.py +++ b/metastable_baselines2/trpl/trpl.py @@ -87,7 +87,7 @@ class TRPL(BetterOnPolicyAlgorithm): n_epochs: int = 10, gamma: float = 0.99, gae_lambda: float = 0.95, - clip_range: Union[float, Schedule] = 0.2, + clip_range: Union[float, Schedule, None] = None, clip_range_vf: Union[None, float, Schedule] = None, normalize_advantage: bool = True, ent_coef: float = 0.0, @@ -169,7 +169,11 @@ class TRPL(BetterOnPolicyAlgorithm): ) self.batch_size = batch_size self.n_epochs = n_epochs + if clip_range == False: + clip_range = None self.clip_range = clip_range + if clip_range_vf == False: + clip_range_vf = None self.clip_range_vf = clip_range_vf self.normalize_advantage = normalize_advantage self.projection = castProjection(projection)