Tell PCA to use mSqrt when using W2 proj
This commit is contained in:
parent
4831fbb7e1
commit
ada300fa63
@ -121,6 +121,10 @@ class TRPL(BetterOnPolicyAlgorithm):
|
|||||||
if policy_kwargs is None:
|
if policy_kwargs is None:
|
||||||
policy_kwargs = {}
|
policy_kwargs = {}
|
||||||
policy_kwargs['policy_projection'] = self.projection
|
policy_kwargs['policy_projection'] = self.projection
|
||||||
|
if 'dist_kwargs' not in policy_kwargs:
|
||||||
|
policy_kwargs['dist_kwargs'] = {}
|
||||||
|
if use_pca:
|
||||||
|
policy_kwargs['dist_kwargs']['msqrt_induces_full'] = self.projection_class == WassersteinProjectionLayer
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
policy,
|
policy,
|
||||||
|
Loading…
Reference in New Issue
Block a user