diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 30cfab4..b65c25f 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -111,7 +111,8 @@ def make_proba_distribution( if dist_kwargs is None: dist_kwargs = {} - dist_kwargs['use_sde'] = use_sde + if not use_pca: + dist_kwargs['use_sde'] = use_sde if isinstance(action_space, gym.spaces.Box): assert len( diff --git a/metastable_baselines/ppo/policies.py b/metastable_baselines/ppo/policies.py index 108108b..644c56a 100644 --- a/metastable_baselines/ppo/policies.py +++ b/metastable_baselines/ppo/policies.py @@ -156,6 +156,7 @@ class ActorCriticPolicy(BasePolicy): "sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning) self.use_sde = use_sde + print('[i] Use PCA? '+['No', 'Yes'][use_pca]) self.use_pca = use_pca self.dist_kwargs = dist_kwargs