diff --git a/fancy_rl/utils.py b/fancy_rl/utils.py index 42b5c12..0469aaa 100644 --- a/fancy_rl/utils.py +++ b/fancy_rl/utils.py @@ -33,24 +33,28 @@ def is_discrete_space(action_space): raise ValueError(f"Unsupported action space type: {type(action_space)}") def get_space_shape(action_space): - if gym_available: - discrete_types = (GymDiscrete, GymMultiDiscrete, GymMultiBinary) - continuous_types = (GymBox,) - else: - discrete_types = () - continuous_types = () + discrete_types = (GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary, + DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec) + continuous_types = (GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec) - discrete_types += (GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary, - DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec) - continuous_types += (GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec) + if gym_available: + discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary) + continuous_types += (GymBox,) if isinstance(action_space, discrete_types): - if isinstance(action_space, (GymDiscrete, GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec)): + if isinstance(action_space, (GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec)): return (action_space.n,) - elif isinstance(action_space, (GymMultiDiscrete, GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)): + elif isinstance(action_space, (GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)): return (sum(action_space.nvec),) - elif isinstance(action_space, (GymMultiBinary, GymnasiumMultiBinary, BinaryDiscreteTensorSpec)): + elif isinstance(action_space, (GymnasiumMultiBinary, BinaryDiscreteTensorSpec)): return (action_space.n,) + elif gym_available: + if isinstance(action_space, GymDiscrete): + return (action_space.n,) + elif isinstance(action_space, GymMultiDiscrete): + return (sum(action_space.nvec),) + elif isinstance(action_space, GymMultiBinary): + return (action_space.n,) elif isinstance(action_space, continuous_types): return action_space.shape