Tell PCA to use mSqrt when using W2 proj

This commit is contained in:
Dominik Moritz Roth 2024-03-30 14:44:20 +01:00
parent 4831fbb7e1
commit ada300fa63

View File

@ -121,6 +121,10 @@ class TRPL(BetterOnPolicyAlgorithm):
if policy_kwargs is None:
policy_kwargs = {}
policy_kwargs['policy_projection'] = self.projection
if 'dist_kwargs' not in policy_kwargs:
policy_kwargs['dist_kwargs'] = {}
if use_pca:
policy_kwargs['dist_kwargs']['msqrt_induces_full'] = self.projection_class == WassersteinProjectionLayer
super().__init__(
policy,