diff --git a/metastable_baselines2/trpl/trpl.py b/metastable_baselines2/trpl/trpl.py index 94a020f..cd1ba96 100644 --- a/metastable_baselines2/trpl/trpl.py +++ b/metastable_baselines2/trpl/trpl.py @@ -121,6 +121,10 @@ class TRPL(BetterOnPolicyAlgorithm): if policy_kwargs is None: policy_kwargs = {} policy_kwargs['policy_projection'] = self.projection + if 'dist_kwargs' not in policy_kwargs: + policy_kwargs['dist_kwargs'] = {} + if use_pca: + policy_kwargs['dist_kwargs']['msqrt_induces_full'] = self.projection_class == WassersteinProjectionLayer super().__init__( policy,