metastable-baselines2/metastable_baselines2/common/distributions.py
2024-03-09 13:58:56 +01:00

46 lines
1.9 KiB
Python

from stable_baselines3.common.distributions import *
from metastable_baselines2.common.pca import PCA_Distribution
def _patched_make_proba_distribution(
    action_space: spaces.Space,
    n_envs: int = 1,
    use_sde: bool = False,
    use_pca: bool = False,
    dist_kwargs: Optional[Dict[str, Any]] = None,
) -> Distribution:
    """
    Return an instance of Distribution for the correct type of action space.

    ``use_sde`` and ``use_pca`` only influence the choice for ``Box`` action
    spaces; they are ignored for every other space type.

    :param action_space: the input action space
    :param n_envs: forwarded as ``n_envs`` to PCA_Distribution; unused unless
        ``use_pca`` is set
    :param use_sde: Force the use of StateDependentNoiseDistribution
        instead of DiagGaussianDistribution
    :param use_pca: Force the use of PCA_Distribution instead of
        DiagGaussianDistribution; must not be combined with ``use_sde``
    :param dist_kwargs: Keyword arguments to pass to the probability distribution
    :return: the appropriate Distribution object
    :raises AssertionError: if both ``use_sde`` and ``use_pca`` are set
    :raises NotImplementedError: if the action space is not one of Box,
        Discrete, MultiDiscrete or MultiBinary
    """
    assert not (use_sde and use_pca), 'Can not mix sde and pca!'
    if dist_kwargs is None:
        dist_kwargs = {}

    if isinstance(action_space, spaces.Box):
        action_dim = get_action_dim(action_space)
        if use_sde:
            cls = StateDependentNoiseDistribution
        elif use_pca:
            # PCA_Distribution is the only distribution that is handed n_envs.
            return PCA_Distribution(action_dim, n_envs=n_envs, **dist_kwargs)
        else:
            cls = DiagGaussianDistribution
        return cls(action_dim, **dist_kwargs)
    elif isinstance(action_space, spaces.Discrete):
        return CategoricalDistribution(action_space.n, **dist_kwargs)
    elif isinstance(action_space, spaces.MultiDiscrete):
        return MultiCategoricalDistribution(list(action_space.nvec), **dist_kwargs)
    elif isinstance(action_space, spaces.MultiBinary):
        return BernoulliDistribution(action_space.n, **dist_kwargs)
    else:
        # The literals are concatenated implicitly, so each fragment must
        # carry its own separating space.
        raise NotImplementedError(
            "Error: probability distribution, not implemented for action space "
            f"of type {type(action_space)}."
            " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary."
        )
# Keep a handle on the factory that was in scope before patching, then rebind
# this module's `make_proba_distribution` name to the PCA-aware replacement.
# NOTE: the "propa" spelling is preserved on purpose; other code may refer to
# this module attribute by its existing name.
_orig_make_propa_distribution = make_proba_distribution
make_proba_distribution = _patched_make_proba_distribution