diff --git a/metastable_baselines2/common/distributions.py b/metastable_baselines2/common/distributions.py index ab69dd4..fd80e06 100644 --- a/metastable_baselines2/common/distributions.py +++ b/metastable_baselines2/common/distributions.py @@ -24,9 +24,10 @@ def _patched_make_proba_distribution( cls = StateDependentNoiseDistribution elif use_pca: cls = PCA_Distribution + return cls(get_action_dim(action_space), n_envs=n_envs, **dist_kwargs) else: cls = DiagGaussianDistribution - return cls(get_action_dim(action_space), n_envs=n_envs, **dist_kwargs) + return cls(get_action_dim(action_space), **dist_kwargs) elif isinstance(action_space, spaces.Discrete): return CategoricalDistribution(action_space.n, **dist_kwargs) elif isinstance(action_space, spaces.MultiDiscrete):