Tell distributions the n_envs (e.g. for Pink Noise pregeneration)
This commit is contained in:
parent
b2384e183c
commit
8f66a34c29
@ -2,7 +2,7 @@ from stable_baselines3.common.distributions import *
|
|||||||
from metastable_baselines2.common.pca import PCA_Distribution
|
from metastable_baselines2.common.pca import PCA_Distribution
|
||||||
|
|
||||||
def _patched_make_proba_distribution(
|
def _patched_make_proba_distribution(
|
||||||
action_space: spaces.Space, use_sde: bool = False, use_pca: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
|
action_space: spaces.Space, n_envs: int = 1, use_sde: bool = False, use_pca: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
|
||||||
) -> Distribution:
|
) -> Distribution:
|
||||||
"""
|
"""
|
||||||
Return an instance of Distribution for the correct type of action space
|
Return an instance of Distribution for the correct type of action space
|
||||||
@ -26,7 +26,7 @@ def _patched_make_proba_distribution(
|
|||||||
cls = PCA_Distribution
|
cls = PCA_Distribution
|
||||||
else:
|
else:
|
||||||
cls = DiagGaussianDistribution
|
cls = DiagGaussianDistribution
|
||||||
return cls(get_action_dim(action_space), **dist_kwargs)
|
return cls(get_action_dim(action_space), n_envs=n_envs, **dist_kwargs)
|
||||||
elif isinstance(action_space, spaces.Discrete):
|
elif isinstance(action_space, spaces.Discrete):
|
||||||
return CategoricalDistribution(action_space.n, **dist_kwargs)
|
return CategoricalDistribution(action_space.n, **dist_kwargs)
|
||||||
elif isinstance(action_space, spaces.MultiDiscrete):
|
elif isinstance(action_space, spaces.MultiDiscrete):
|
||||||
|
@ -92,6 +92,8 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
|||||||
|
|
||||||
assert not rollout_buffer_class and not rollout_buffer_kwargs
|
assert not rollout_buffer_class and not rollout_buffer_kwargs
|
||||||
|
|
||||||
|
policy_kwargs['dist_kwargs']['n_envs'] = len(env.envs)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
policy=policy,
|
policy=policy,
|
||||||
env=env,
|
env=env,
|
||||||
|
@ -527,8 +527,10 @@ class ActorCriticPolicy(BasePolicy):
|
|||||||
|
|
||||||
self.policy_projection = policy_projection
|
self.policy_projection = policy_projection
|
||||||
|
|
||||||
|
self.n_envs = dist_kwargs.pop('n_envs', 1)
|
||||||
|
|
||||||
# Action distribution
|
# Action distribution
|
||||||
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, use_pca=use_pca, dist_kwargs=dist_kwargs)
|
self.action_dist = make_proba_distribution(action_space, self.n_envs, use_sde=use_sde, use_pca=use_pca, dist_kwargs=dist_kwargs)
|
||||||
|
|
||||||
self._build(lr_schedule)
|
self._build(lr_schedule)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user