From dbd7c37da558eda79b6cc7bc6812456760017d59 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Sat, 27 May 2023 12:54:30 +0200
Subject: [PATCH] Also support old gym Box as observation_space (backwards compat)

---
 fancy_gym/utils/time_aware_observation.py | 31 ++++++++++++++++-------
 1 file changed, 22 insertions(+), 9 deletions(-)

diff --git a/fancy_gym/utils/time_aware_observation.py b/fancy_gym/utils/time_aware_observation.py
index c1aea7f..12c3762 100644
--- a/fancy_gym/utils/time_aware_observation.py
+++ b/fancy_gym/utils/time_aware_observation.py
@@ -1,4 +1,5 @@
-from gymnasium.spaces import Box
+from gymnasium.spaces import Box, Dict
+from gym.spaces import Box as OldBox
 import gymnasium as gym
 import numpy as np
 
@@ -6,27 +7,39 @@ import numpy as np
 class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
     """Augment the observation with the current time step in the episode.
 
-    The observation space of the wrapped environment is assumed to be a flat :class:`Box`.
-    In particular, pixel observations are not supported. This wrapper will append the current timestep within the current episode to the observation.
-    The timestep will be indicated as a number between 0 and 1.
+    The observation space of the wrapped environment is assumed to be a flat :class:`Box` or flattable :class:`Dict`.
+    In particular, pixel observations are not supported. This wrapper will append the current progress within the current episode to the observation.
+    The progress will be indicated as a number between 0 and 1.
     """
 
     def __init__(self, env: gym.Env, enforce_dtype_float32=False):
-        """Initialize :class:`TimeAwareObservation` that requires an environment with a flat :class:`Box` observation space.
+        """Initialize :class:`TimeAwareObservation` that requires an environment with a flat :class:`Box` or flattable :class:`Dict` observation space.
 
         Args:
             env: The environment to apply the wrapper
+            enforce_dtype_float32: If True, assert that the observation space dtype is np.float32
+
+        Raises:
+            NotImplementedError: If the observation space is a :class:`Dict` (recognised, but not supported yet)
         """
         gym.utils.RecordConstructorArgs.__init__(self)
         gym.ObservationWrapper.__init__(self, env)
-        assert isinstance(env.observation_space, Box)
+        assert isinstance(env.observation_space, (Box, OldBox, Dict)), str(env.observation_space)+' is not supported. Only Box or Dict'
+        if isinstance(env.observation_space, Dict):
+            # A Dict space has no .low/.high to extend, and appending the progress
+            # entry to it is not implemented: fail loudly instead of opening a debugger.
+            raise NotImplementedError('TimeAwareObservation does not support Dict observation spaces yet.')
+
         if enforce_dtype_float32:
             assert env.observation_space.dtype == np.float32, 'TimeAwareObservation was given an environment with a dtype!=np.float32 ('+str(
                 env.observation_space.dtype)+'). This requirement can be removed by setting enforce_dtype_float32=False.'
         dtype = env.observation_space.dtype
-        low = np.append(self.observation_space.low, 0.0)
-        high = np.append(self.observation_space.high, np.inf)
-        self.observation_space = Box(low, high, dtype=dtype)
+
+        # The extra trailing entry is the episode progress, bounded by [0, 1].
+        low = np.append(env.observation_space.low, 0.0)
+        high = np.append(env.observation_space.high, 1.0)
+        self.observation_space = Box(low, high, dtype=dtype)
+
         self.is_vector_env = getattr(env, "is_vector_env", False)
 
     def observation(self, observation):