From e3f4c511bf27e5e943f29b934cbebb53bdf233c3 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 23 Jan 2024 09:20:34 +0100 Subject: [PATCH] Better default HPs for TRPL --- metastable_baselines2/trpl/trpl.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/metastable_baselines2/trpl/trpl.py b/metastable_baselines2/trpl/trpl.py index ea16a50..ed819b1 100644 --- a/metastable_baselines2/trpl/trpl.py +++ b/metastable_baselines2/trpl/trpl.py @@ -87,7 +87,7 @@ class TRPL(BetterOnPolicyAlgorithm): n_epochs: int = 10, gamma: float = 0.99, gae_lambda: float = 0.95, - clip_range: Union[float, Schedule] = 0.2, + clip_range: Union[float, Schedule, None] = None, clip_range_vf: Union[None, float, Schedule] = None, normalize_advantage: bool = True, ent_coef: float = 0.0, @@ -169,7 +169,11 @@ class TRPL(BetterOnPolicyAlgorithm): ) self.batch_size = batch_size self.n_epochs = n_epochs + if clip_range == False: + clip_range = None self.clip_range = clip_range + if clip_range_vf == False: + clip_range_vf = None self.clip_range_vf = clip_range_vf self.normalize_advantage = normalize_advantage self.projection = castProjection(projection)