Make sure we work without original gym installed
This commit is contained in:
parent
b4f89c9b7a
commit
d51bf948d4
@ -1,9 +1,3 @@
|
||||
try:
|
||||
import gym
|
||||
from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox
|
||||
except ImportError:
|
||||
gym = None
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.spaces import Discrete as GymnasiumDiscrete, MultiDiscrete as GymnasiumMultiDiscrete, MultiBinary as GymnasiumMultiBinary, Box as GymnasiumBox
|
||||
from torchrl.data.tensor_specs import (
|
||||
@ -11,16 +5,26 @@ from torchrl.data.tensor_specs import (
|
||||
BinaryDiscreteTensorSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec
|
||||
)
|
||||
|
||||
try:
|
||||
import gym
|
||||
from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox
|
||||
gym_available = True
|
||||
except ImportError:
|
||||
gym_available = False
|
||||
|
||||
def is_discrete_space(action_space):
|
||||
discrete_types = (
|
||||
GymDiscrete, GymMultiDiscrete, GymMultiBinary,
|
||||
GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
|
||||
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec
|
||||
)
|
||||
continuous_types = (
|
||||
GymBox, GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec
|
||||
GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec
|
||||
)
|
||||
|
||||
if gym_available:
|
||||
discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
|
||||
continuous_types += (GymBox,)
|
||||
|
||||
if isinstance(action_space, discrete_types):
|
||||
return True
|
||||
elif isinstance(action_space, continuous_types):
|
||||
@ -29,7 +33,7 @@ def is_discrete_space(action_space):
|
||||
raise ValueError(f"Unsupported action space type: {type(action_space)}")
|
||||
|
||||
def get_space_shape(action_space):
|
||||
if gym is not None:
|
||||
if gym_available:
|
||||
discrete_types = (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
|
||||
continuous_types = (GymBox,)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user