Make sure we work without original gym installed
This commit is contained in:
parent
b4f89c9b7a
commit
d51bf948d4
@ -1,9 +1,3 @@
|
|||||||
try:
|
|
||||||
import gym
|
|
||||||
from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox
|
|
||||||
except ImportError:
|
|
||||||
gym = None
|
|
||||||
|
|
||||||
import gymnasium
|
import gymnasium
|
||||||
from gymnasium.spaces import Discrete as GymnasiumDiscrete, MultiDiscrete as GymnasiumMultiDiscrete, MultiBinary as GymnasiumMultiBinary, Box as GymnasiumBox
|
from gymnasium.spaces import Discrete as GymnasiumDiscrete, MultiDiscrete as GymnasiumMultiDiscrete, MultiBinary as GymnasiumMultiBinary, Box as GymnasiumBox
|
||||||
from torchrl.data.tensor_specs import (
|
from torchrl.data.tensor_specs import (
|
||||||
@ -11,16 +5,26 @@ from torchrl.data.tensor_specs import (
|
|||||||
BinaryDiscreteTensorSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec
|
BinaryDiscreteTensorSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import gym
|
||||||
|
from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox
|
||||||
|
gym_available = True
|
||||||
|
except ImportError:
|
||||||
|
gym_available = False
|
||||||
|
|
||||||
def is_discrete_space(action_space):
|
def is_discrete_space(action_space):
|
||||||
discrete_types = (
|
discrete_types = (
|
||||||
GymDiscrete, GymMultiDiscrete, GymMultiBinary,
|
|
||||||
GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
|
GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
|
||||||
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec
|
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec
|
||||||
)
|
)
|
||||||
continuous_types = (
|
continuous_types = (
|
||||||
GymBox, GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec
|
GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if gym_available:
|
||||||
|
discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
|
||||||
|
continuous_types += (GymBox,)
|
||||||
|
|
||||||
if isinstance(action_space, discrete_types):
|
if isinstance(action_space, discrete_types):
|
||||||
return True
|
return True
|
||||||
elif isinstance(action_space, continuous_types):
|
elif isinstance(action_space, continuous_types):
|
||||||
@ -29,7 +33,7 @@ def is_discrete_space(action_space):
|
|||||||
raise ValueError(f"Unsupported action space type: {type(action_space)}")
|
raise ValueError(f"Unsupported action space type: {type(action_space)}")
|
||||||
|
|
||||||
def get_space_shape(action_space):
|
def get_space_shape(action_space):
|
||||||
if gym is not None:
|
if gym_available:
|
||||||
discrete_types = (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
|
discrete_types = (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
|
||||||
continuous_types = (GymBox,)
|
continuous_types = (GymBox,)
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user