diff --git a/metastable_baselines2/common/distributions.py b/metastable_baselines2/common/distributions.py index 735ba2f..ab69dd4 100644 --- a/metastable_baselines2/common/distributions.py +++ b/metastable_baselines2/common/distributions.py @@ -2,7 +2,7 @@ from stable_baselines3.common.distributions import * from metastable_baselines2.common.pca import PCA_Distribution def _patched_make_proba_distribution( - action_space: spaces.Space, use_sde: bool = False, use_pca: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None + action_space: spaces.Space, n_envs: int = 1, use_sde: bool = False, use_pca: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None ) -> Distribution: """ Return an instance of Distribution for the correct type of action space @@ -26,7 +26,7 @@ def _patched_make_proba_distribution( cls = PCA_Distribution else: cls = DiagGaussianDistribution - return cls(get_action_dim(action_space), **dist_kwargs) + return cls(get_action_dim(action_space), n_envs=n_envs, **dist_kwargs) elif isinstance(action_space, spaces.Discrete): return CategoricalDistribution(action_space.n, **dist_kwargs) elif isinstance(action_space, spaces.MultiDiscrete): diff --git a/metastable_baselines2/common/on_policy_algorithm.py b/metastable_baselines2/common/on_policy_algorithm.py index a145852..df3d239 100644 --- a/metastable_baselines2/common/on_policy_algorithm.py +++ b/metastable_baselines2/common/on_policy_algorithm.py @@ -92,6 +92,8 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm): assert not rollout_buffer_class and not rollout_buffer_kwargs + policy_kwargs['dist_kwargs']['n_envs'] = len(env.envs) + super().__init__( policy=policy, env=env, diff --git a/metastable_baselines2/common/policies.py b/metastable_baselines2/common/policies.py index 937bcb0..3923eb6 100644 --- a/metastable_baselines2/common/policies.py +++ b/metastable_baselines2/common/policies.py @@ -527,8 +527,10 @@ class ActorCriticPolicy(BasePolicy): self.policy_projection = policy_projection + self.n_envs = dist_kwargs.pop('n_envs', 1) + # Action distribution - self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, use_pca=use_pca, dist_kwargs=dist_kwargs) + self.action_dist = make_proba_distribution(action_space, self.n_envs, use_sde=use_sde, use_pca=use_pca, dist_kwargs=dist_kwargs) self._build(lr_schedule)