46 lines
1.9 KiB
Python
46 lines
1.9 KiB
Python
from stable_baselines3.common.distributions import *
|
|
from metastable_baselines2.common.pca import PCA_Distribution
|
|
|
|
def _patched_make_proba_distribution(
    action_space: spaces.Space, n_envs: int = 1, use_sde: bool = False, use_pca: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
) -> Distribution:
    """
    Return an instance of Distribution for the correct type of action space

    :param action_space: the input action space
    :param n_envs: forwarded as ``n_envs`` to PCA_Distribution; unused by
        every other distribution built here
    :param use_sde: Force the use of StateDependentNoiseDistribution
        instead of DiagGaussianDistribution (only consulted for Box spaces)
    :param use_pca: Force the use of PCA_Distribution instead of
        DiagGaussianDistribution (only consulted for Box spaces);
        mutually exclusive with ``use_sde``
    :param dist_kwargs: Keyword arguments to pass to the probability distribution
    :return: the appropriate Distribution object
    :raises AssertionError: if both ``use_sde`` and ``use_pca`` are set
    :raises NotImplementedError: if ``action_space`` is not a Box, Discrete,
        MultiDiscrete or MultiBinary space
    """
    # NOTE: kept as an assert so callers see the same exception type as before;
    # be aware that the check is skipped when Python runs with -O.
    assert not (use_sde and use_pca), 'Can not mix sde and pca!'

    if dist_kwargs is None:
        dist_kwargs = {}

    if isinstance(action_space, spaces.Box):
        if use_sde:
            cls = StateDependentNoiseDistribution
        elif use_pca:
            cls = PCA_Distribution
            # PCA_Distribution is the only distribution here that receives n_envs.
            return cls(get_action_dim(action_space), n_envs=n_envs, **dist_kwargs)
        else:
            cls = DiagGaussianDistribution
        return cls(get_action_dim(action_space), **dist_kwargs)
    elif isinstance(action_space, spaces.Discrete):
        return CategoricalDistribution(action_space.n, **dist_kwargs)
    elif isinstance(action_space, spaces.MultiDiscrete):
        return MultiCategoricalDistribution(list(action_space.nvec), **dist_kwargs)
    elif isinstance(action_space, spaces.MultiBinary):
        return BernoulliDistribution(action_space.n, **dist_kwargs)
    else:
        # Adjacent literals are concatenated verbatim, so the first fragment
        # must end with a space to avoid rendering "action spaceof type".
        raise NotImplementedError(
            "Error: probability distribution, not implemented for action space "
            f"of type {type(action_space)}."
            " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary."
        )
|
|
|
|
|
|
# Keep a handle on the factory that was in scope before patching (the name is
# spelled "propa" (sic); left unchanged in case other modules refer to it).
_orig_make_propa_distribution = make_proba_distribution
# Rebind this module's `make_proba_distribution` to the PCA-aware variant.
# NOTE(review): this assignment only changes the name in this module's
# namespace; presumably callers are expected to import it from here - confirm.
make_proba_distribution = _patched_make_proba_distribution
|