Fix: Tried to reference gym space classes even if no gym avaible

This commit is contained in:
Dominik Moritz Roth 2024-08-30 08:05:41 +02:00
parent 906240e145
commit 71cb8593d9

View File

@ -33,23 +33,27 @@ def is_discrete_space(action_space):
raise ValueError(f"Unsupported action space type: {type(action_space)}") raise ValueError(f"Unsupported action space type: {type(action_space)}")
def get_space_shape(action_space): def get_space_shape(action_space):
if gym_available: discrete_types = (GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
discrete_types = (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
continuous_types = (GymBox,)
else:
discrete_types = ()
continuous_types = ()
discrete_types += (GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec) DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec)
continuous_types += (GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec) continuous_types = (GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec)
if gym_available:
discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
continuous_types += (GymBox,)
if isinstance(action_space, discrete_types): if isinstance(action_space, discrete_types):
if isinstance(action_space, (GymDiscrete, GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec)): if isinstance(action_space, (GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec)):
return (action_space.n,) return (action_space.n,)
elif isinstance(action_space, (GymMultiDiscrete, GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)): elif isinstance(action_space, (GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)):
return (sum(action_space.nvec),) return (sum(action_space.nvec),)
elif isinstance(action_space, (GymMultiBinary, GymnasiumMultiBinary, BinaryDiscreteTensorSpec)): elif isinstance(action_space, (GymnasiumMultiBinary, BinaryDiscreteTensorSpec)):
return (action_space.n,)
elif gym_available:
if isinstance(action_space, GymDiscrete):
return (action_space.n,)
elif isinstance(action_space, GymMultiDiscrete):
return (sum(action_space.nvec),)
elif isinstance(action_space, GymMultiBinary):
return (action_space.n,) return (action_space.n,)
elif isinstance(action_space, continuous_types): elif isinstance(action_space, continuous_types):
return action_space.shape return action_space.shape