Tell distributions the n_envs (for e.g. Pink Noise pregen)

This commit is contained in:
Dominik Moritz Roth 2024-03-09 13:46:23 +01:00
parent b2384e183c
commit 8f66a34c29
3 changed files with 7 additions and 3 deletions

View File

@ -2,7 +2,7 @@ from stable_baselines3.common.distributions import *
from metastable_baselines2.common.pca import PCA_Distribution
def _patched_make_proba_distribution(
action_space: spaces.Space, use_sde: bool = False, use_pca: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
action_space: spaces.Space, n_envs: int = 1, use_sde: bool = False, use_pca: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
) -> Distribution:
"""
Return an instance of Distribution for the correct type of action space
@ -26,7 +26,7 @@ def _patched_make_proba_distribution(
cls = PCA_Distribution
else:
cls = DiagGaussianDistribution
return cls(get_action_dim(action_space), **dist_kwargs)
return cls(get_action_dim(action_space), n_envs=n_envs, **dist_kwargs)
elif isinstance(action_space, spaces.Discrete):
return CategoricalDistribution(action_space.n, **dist_kwargs)
elif isinstance(action_space, spaces.MultiDiscrete):

View File

@ -92,6 +92,8 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
assert not rollout_buffer_class and not rollout_buffer_kwargs
policy_kwargs['dist_kwargs']['n_envs'] = len(env.envs)
super().__init__(
policy=policy,
env=env,

View File

@ -527,8 +527,10 @@ class ActorCriticPolicy(BasePolicy):
self.policy_projection = policy_projection
self.n_envs = dist_kwargs.pop('n_envs', 1)
# Action distribution
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, use_pca=use_pca, dist_kwargs=dist_kwargs)
self.action_dist = make_proba_distribution(action_space, self.n_envs, use_sde=use_sde, use_pca=use_pca, dist_kwargs=dist_kwargs)
self._build(lr_schedule)