diff --git a/metastable_baselines2/trpl/trpl.py b/metastable_baselines2/trpl/trpl.py index 83d78ab..94a020f 100644 --- a/metastable_baselines2/trpl/trpl.py +++ b/metastable_baselines2/trpl/trpl.py @@ -114,6 +114,8 @@ class TRPL(BetterOnPolicyAlgorithm): _init_setup_model: bool = True, ): self.projection_class = castProjection(projection_class) + if projection_kwargs is None: + projection_kwargs = {} self.projection_kwargs = projection_kwargs self.projection = self.projection_class(**self.projection_kwargs) if policy_kwargs is None: