From 8f66a34c29d004aff5c5d8824b5df86fecd1e8d8 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 9 Mar 2024 13:46:23 +0100 Subject: [PATCH] Tell distributions the n_envs (for e.g. Pink Noise pregen) --- metastable_baselines2/common/distributions.py | 4 ++-- metastable_baselines2/common/on_policy_algorithm.py | 2 ++ metastable_baselines2/common/policies.py | 4 +++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/metastable_baselines2/common/distributions.py b/metastable_baselines2/common/distributions.py index 735ba2f..ab69dd4 100644 --- a/metastable_baselines2/common/distributions.py +++ b/metastable_baselines2/common/distributions.py @@ -2,7 +2,7 @@ from stable_baselines3.common.distributions import * from metastable_baselines2.common.pca import PCA_Distribution def _patched_make_proba_distribution( - action_space: spaces.Space, use_sde: bool = False, use_pca: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None + action_space: spaces.Space, n_envs: int = 1, use_sde: bool = False, use_pca: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None ) -> Distribution: """ Return an instance of Distribution for the correct type of action space @@ -26,7 +26,7 @@ def _patched_make_proba_distribution( cls = PCA_Distribution else: cls = DiagGaussianDistribution - return cls(get_action_dim(action_space), **dist_kwargs) + return cls(get_action_dim(action_space), n_envs=n_envs, **dist_kwargs) elif isinstance(action_space, spaces.Discrete): return CategoricalDistribution(action_space.n, **dist_kwargs) elif isinstance(action_space, spaces.MultiDiscrete): diff --git a/metastable_baselines2/common/on_policy_algorithm.py b/metastable_baselines2/common/on_policy_algorithm.py index a145852..df3d239 100644 --- a/metastable_baselines2/common/on_policy_algorithm.py +++ b/metastable_baselines2/common/on_policy_algorithm.py @@ -92,6 +92,8 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm): assert not rollout_buffer_class and not rollout_buffer_kwargs + policy_kwargs['dist_kwargs']['n_envs'] = len(env.envs) + super().__init__( policy=policy, env=env, diff --git a/metastable_baselines2/common/policies.py b/metastable_baselines2/common/policies.py index 937bcb0..3923eb6 100644 --- a/metastable_baselines2/common/policies.py +++ b/metastable_baselines2/common/policies.py @@ -527,8 +527,10 @@ class ActorCriticPolicy(BasePolicy): self.policy_projection = policy_projection + self.n_envs = dist_kwargs.pop('n_envs', 1) + # Action distribution - self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, use_pca=use_pca, dist_kwargs=dist_kwargs) + self.action_dist = make_proba_distribution(action_space, self.n_envs, use_sde=use_sde, use_pca=use_pca, dist_kwargs=dist_kwargs) self._build(lr_schedule)