Make sure we work without original gym installed

This commit is contained in:
Dominik Moritz Roth 2024-06-02 12:08:48 +02:00
parent b4f89c9b7a
commit d51bf948d4

View File

@ -1,9 +1,3 @@
try:
import gym
from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox
except ImportError:
gym = None
import gymnasium
from gymnasium.spaces import Discrete as GymnasiumDiscrete, MultiDiscrete as GymnasiumMultiDiscrete, MultiBinary as GymnasiumMultiBinary, Box as GymnasiumBox
from torchrl.data.tensor_specs import (
@ -11,16 +5,26 @@ from torchrl.data.tensor_specs import (
BinaryDiscreteTensorSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec
)
try:
import gym
from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox
gym_available = True
except ImportError:
gym_available = False
def is_discrete_space(action_space):
discrete_types = (
GymDiscrete, GymMultiDiscrete, GymMultiBinary,
GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec
)
continuous_types = (
GymBox, GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec
GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec
)
if gym_available:
discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
continuous_types += (GymBox,)
if isinstance(action_space, discrete_types):
return True
elif isinstance(action_space, continuous_types):
@ -29,7 +33,7 @@ def is_discrete_space(action_space):
raise ValueError(f"Unsupported action space type: {type(action_space)}")
def get_space_shape(action_space):
if gym is not None:
if gym_available:
discrete_types = (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
continuous_types = (GymBox,)
else: