Make sure we work without original gym installed

This commit is contained in:
Dominik Moritz Roth 2024-06-02 12:08:48 +02:00
parent b4f89c9b7a
commit d51bf948d4

View File

@ -1,9 +1,3 @@
try:
import gym
from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox
except ImportError:
gym = None
import gymnasium import gymnasium
from gymnasium.spaces import Discrete as GymnasiumDiscrete, MultiDiscrete as GymnasiumMultiDiscrete, MultiBinary as GymnasiumMultiBinary, Box as GymnasiumBox from gymnasium.spaces import Discrete as GymnasiumDiscrete, MultiDiscrete as GymnasiumMultiDiscrete, MultiBinary as GymnasiumMultiBinary, Box as GymnasiumBox
from torchrl.data.tensor_specs import ( from torchrl.data.tensor_specs import (
@ -11,16 +5,26 @@ from torchrl.data.tensor_specs import (
BinaryDiscreteTensorSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec BinaryDiscreteTensorSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec
) )
try:
import gym
from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox
gym_available = True
except ImportError:
gym_available = False
def is_discrete_space(action_space): def is_discrete_space(action_space):
discrete_types = ( discrete_types = (
GymDiscrete, GymMultiDiscrete, GymMultiBinary,
GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary, GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec
) )
continuous_types = ( continuous_types = (
GymBox, GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec
) )
if gym_available:
discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
continuous_types += (GymBox,)
if isinstance(action_space, discrete_types): if isinstance(action_space, discrete_types):
return True return True
elif isinstance(action_space, continuous_types): elif isinstance(action_space, continuous_types):
@ -29,7 +33,7 @@ def is_discrete_space(action_space):
raise ValueError(f"Unsupported action space type: {type(action_space)}") raise ValueError(f"Unsupported action space type: {type(action_space)}")
def get_space_shape(action_space): def get_space_shape(action_space):
if gym is not None: if gym_available:
discrete_types = (GymDiscrete, GymMultiDiscrete, GymMultiBinary) discrete_types = (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
continuous_types = (GymBox,) continuous_types = (GymBox,)
else: else: