Add utils to handle all known space types
This commit is contained in:
parent
5a6069daf4
commit
a3cca71ac9
53
fancy_rl/utils.py
Normal file
53
fancy_rl/utils.py
Normal file
@ -0,0 +1,53 @@
|
||||
try:
|
||||
import gym
|
||||
from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox
|
||||
except ImportError:
|
||||
gym = None
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.spaces import Discrete as GymnasiumDiscrete, MultiDiscrete as GymnasiumMultiDiscrete, MultiBinary as GymnasiumMultiBinary, Box as GymnasiumBox
|
||||
from torchrl.data.tensor_specs import (
|
||||
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec,
|
||||
BinaryDiscreteTensorSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec
|
||||
)
|
||||
|
||||
def is_discrete_space(action_space):
|
||||
discrete_types = (
|
||||
GymDiscrete, GymMultiDiscrete, GymMultiBinary,
|
||||
GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
|
||||
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec
|
||||
)
|
||||
continuous_types = (
|
||||
GymBox, GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec
|
||||
)
|
||||
|
||||
if isinstance(action_space, discrete_types):
|
||||
return True
|
||||
elif isinstance(action_space, continuous_types):
|
||||
return False
|
||||
else:
|
||||
raise ValueError(f"Unsupported action space type: {type(action_space)}")
|
||||
|
||||
def get_space_shape(action_space):
|
||||
if gym is not None:
|
||||
discrete_types = (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
|
||||
continuous_types = (GymBox,)
|
||||
else:
|
||||
discrete_types = ()
|
||||
continuous_types = ()
|
||||
|
||||
discrete_types += (GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
|
||||
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec)
|
||||
continuous_types += (GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec)
|
||||
|
||||
if isinstance(action_space, discrete_types):
|
||||
if isinstance(action_space, (GymDiscrete, GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec)):
|
||||
return (action_space.n,)
|
||||
elif isinstance(action_space, (GymMultiDiscrete, GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)):
|
||||
return (sum(action_space.nvec),)
|
||||
elif isinstance(action_space, (GymMultiBinary, GymnasiumMultiBinary, BinaryDiscreteTensorSpec)):
|
||||
return (action_space.n,)
|
||||
elif isinstance(action_space, continuous_types):
|
||||
return action_space.shape
|
||||
|
||||
raise ValueError(f"Unsupported action space type: {type(action_space)}")
|
Loading…
Reference in New Issue
Block a user