Also support old gym Box as observation_space (backwards compat)

This commit is contained in:
Dominik Moritz Roth 2023-05-27 12:54:30 +02:00
parent 29b458c7df
commit dbd7c37da5

View File

@ -1,4 +1,5 @@
from gymnasium.spaces import Box from gymnasium.spaces import Box, Dict
from gym.spaces import Box as OldBox
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
@ -6,27 +7,37 @@ import numpy as np
class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Augment the observation with the current time step in the episode. """Augment the observation with the current time step in the episode.
The observation space of the wrapped environment is assumed to be a flat :class:`Box`. The observation space of the wrapped environment is assumed to be a flat :class:`Box` or flattable :class:`Dict`.
In particular, pixel observations are not supported. This wrapper will append the current timestep within the current episode to the observation. In particular, pixel observations are not supported. This wrapper will append the current progress within the current episode to the observation.
The timestep will be indicated as a number between 0 and 1. The progress will be indicated as a number between 0 and 1.
""" """
def __init__(self, env: gym.Env, enforce_dtype_float32=False): def __init__(self, env: gym.Env, enforce_dtype_float32=False):
"""Initialize :class:`TimeAwareObservation` that requires an environment with a flat :class:`Box` observation space. """Initialize :class:`TimeAwareObservation` that requires an environment with a flat :class:`Box` or flattable :class:`Dict` observation space.
Args: Args:
env: The environment to apply the wrapper env: The environment to apply the wrapper
""" """
gym.utils.RecordConstructorArgs.__init__(self) gym.utils.RecordConstructorArgs.__init__(self)
gym.ObservationWrapper.__init__(self, env) gym.ObservationWrapper.__init__(self, env)
assert isinstance(env.observation_space, Box) allowed_classes = [Box, OldBox, Dict]
if enforce_dtype_float32: if enforce_dtype_float32:
assert env.observation_space.dtype == np.float32, 'TimeAwareObservation was given an environment with a dtype!=np.float32 ('+str( assert env.observation_space.dtype == np.float32, 'TimeAwareObservation was given an environment with a dtype!=np.float32 ('+str(
env.observation_space.dtype)+'). This requirement can be removed by setting enforce_dtype_float32=False.' env.observation_space.dtype)+'). This requirement can be removed by setting enforce_dtype_float32=False.'
dtype = env.observation_space.dtype dtype = env.observation_space.dtype
low = np.append(self.observation_space.low, 0.0)
high = np.append(self.observation_space.high, np.inf) assert env.observation_space.__class__ in allowed_classes, str(env.observation_space)+' is not supported. Only Box or Dict'
low = np.append(env.observation_space.low, 0.0)
high = np.append(env.observation_space.high, 1.0)
if env.observation_space.__class__ in [Box, OldBox]:
self.observation_space = Box(low, high, dtype=dtype) self.observation_space = Box(low, high, dtype=dtype)
else:
import pdb
pdb.set_trace()
exit
self.is_vector_env = getattr(env, "is_vector_env", False) self.is_vector_env = getattr(env, "is_vector_env", False)
def observation(self, observation): def observation(self, observation):