Tell distributions the n_envs (for e.g. Pink Noise pregen)
This commit is contained in:
parent
b2384e183c
commit
8f66a34c29
@ -2,7 +2,7 @@ from stable_baselines3.common.distributions import *
|
||||
from metastable_baselines2.common.pca import PCA_Distribution
|
||||
|
||||
def _patched_make_proba_distribution(
|
||||
action_space: spaces.Space, use_sde: bool = False, use_pca: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
|
||||
action_space: spaces.Space, n_envs: int = 1, use_sde: bool = False, use_pca: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
|
||||
) -> Distribution:
|
||||
"""
|
||||
Return an instance of Distribution for the correct type of action space
|
||||
@ -26,7 +26,7 @@ def _patched_make_proba_distribution(
|
||||
cls = PCA_Distribution
|
||||
else:
|
||||
cls = DiagGaussianDistribution
|
||||
return cls(get_action_dim(action_space), **dist_kwargs)
|
||||
return cls(get_action_dim(action_space), n_envs=n_envs, **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.Discrete):
|
||||
return CategoricalDistribution(action_space.n, **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.MultiDiscrete):
|
||||
|
@ -92,6 +92,8 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
||||
|
||||
assert not rollout_buffer_class and not rollout_buffer_kwargs
|
||||
|
||||
policy_kwargs['dist_kwargs']['n_envs'] = len(env.envs)
|
||||
|
||||
super().__init__(
|
||||
policy=policy,
|
||||
env=env,
|
||||
|
@ -527,8 +527,10 @@ class ActorCriticPolicy(BasePolicy):
|
||||
|
||||
self.policy_projection = policy_projection
|
||||
|
||||
self.n_envs = dist_kwargs.pop('n_envs', 1)
|
||||
|
||||
# Action distribution
|
||||
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, use_pca=use_pca, dist_kwargs=dist_kwargs)
|
||||
self.action_dist = make_proba_distribution(action_space, self.n_envs, use_sde=use_sde, use_pca=use_pca, dist_kwargs=dist_kwargs)
|
||||
|
||||
self._build(lr_schedule)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user