From ada300fa63e5c3ce41316f64da59abde2fb18348 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 30 Mar 2024 14:44:20 +0100 Subject: [PATCH] Tell PCA to use mSqrt when using W2 proj --- metastable_baselines2/trpl/trpl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/metastable_baselines2/trpl/trpl.py b/metastable_baselines2/trpl/trpl.py index 94a020f..cd1ba96 100644 --- a/metastable_baselines2/trpl/trpl.py +++ b/metastable_baselines2/trpl/trpl.py @@ -121,6 +121,10 @@ class TRPL(BetterOnPolicyAlgorithm): if policy_kwargs is None: policy_kwargs = {} policy_kwargs['policy_projection'] = self.projection + if 'dist_kwargs' not in policy_kwargs: + policy_kwargs['dist_kwargs'] = {} + if use_pca: + policy_kwargs['dist_kwargs']['msqrt_induces_full'] = self.projection_class == WassersteinProjectionLayer super().__init__( policy,