diff --git a/metastable_baselines2/trpl/trpl.py b/metastable_baselines2/trpl/trpl.py index edd3be8..83d78ab 100644 --- a/metastable_baselines2/trpl/trpl.py +++ b/metastable_baselines2/trpl/trpl.py @@ -116,6 +116,8 @@ class TRPL(BetterOnPolicyAlgorithm): self.projection_class = castProjection(projection_class) self.projection_kwargs = projection_kwargs self.projection = self.projection_class(**self.projection_kwargs) + if policy_kwargs is None: + policy_kwargs = {} policy_kwargs['policy_projection'] = self.projection super().__init__(