Tell PCA to use mSqrt when using W2 proj
This commit is contained in:
parent
4831fbb7e1
commit
ada300fa63
@ -121,6 +121,10 @@ class TRPL(BetterOnPolicyAlgorithm):
|
||||
if policy_kwargs is None:
|
||||
policy_kwargs = {}
|
||||
policy_kwargs['policy_projection'] = self.projection
|
||||
if 'dist_kwargs' not in policy_kwargs:
|
||||
policy_kwargs['dist_kwargs'] = {}
|
||||
if use_pca:
|
||||
policy_kwargs['dist_kwargs']['msqrt_induces_full'] = self.projection_class == WassersteinProjectionLayer
|
||||
|
||||
super().__init__(
|
||||
policy,
|
||||
|
Loading…
Reference in New Issue
Block a user