diff --git a/fancy_rl/utils.py b/fancy_rl/utils.py index aace333..42b5c12 100644 --- a/fancy_rl/utils.py +++ b/fancy_rl/utils.py @@ -1,9 +1,3 @@ -try: - import gym - from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox -except ImportError: - gym = None - import gymnasium from gymnasium.spaces import Discrete as GymnasiumDiscrete, MultiDiscrete as GymnasiumMultiDiscrete, MultiBinary as GymnasiumMultiBinary, Box as GymnasiumBox from torchrl.data.tensor_specs import ( @@ -11,16 +5,26 @@ from torchrl.data.tensor_specs import ( BinaryDiscreteTensorSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec ) +try: + import gym + from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox + gym_available = True +except ImportError: + gym_available = False + def is_discrete_space(action_space): discrete_types = ( - GymDiscrete, GymMultiDiscrete, GymMultiBinary, GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary, DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec ) continuous_types = ( - GymBox, GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec + GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec ) + if gym_available: + discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary) + continuous_types += (GymBox,) + if isinstance(action_space, discrete_types): return True elif isinstance(action_space, continuous_types): @@ -29,7 +33,7 @@ def is_discrete_space(action_space): raise ValueError(f"Unsupported action space type: {type(action_space)}") def get_space_shape(action_space): - if gym is not None: + if gym_available: discrete_types = (GymDiscrete, GymMultiDiscrete, GymMultiBinary) continuous_types = (GymBox,) else: