From a23b44752e58fcd64ad4c757d4c98c2776cffd28 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 11 Jun 2023 17:38:16 +0200 Subject: [PATCH] Implement support for Dict spaces for time_aware_observation-wrapper --- fancy_gym/utils/time_aware_observation.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/fancy_gym/utils/time_aware_observation.py b/fancy_gym/utils/time_aware_observation.py index a042438..61290dd 100644 --- a/fancy_gym/utils/time_aware_observation.py +++ b/fancy_gym/utils/time_aware_observation.py @@ -2,6 +2,7 @@ from gymnasium.spaces import Box, Dict from gym.spaces import Box as OldBox import gymnasium as gym import numpy as np +import copy class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): @@ -24,19 +25,21 @@ class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorAr if enforce_dtype_float32: assert env.observation_space.dtype == np.float32, 'TimeAwareObservation was given an environment with a dtype!=np.float32 ('+str( env.observation_space.dtype)+'). This requirement can be removed by setting enforce_dtype_float32=False.' - dtype = env.observation_space.dtype - assert env.observation_space.__class__ in allowed_classes, str(env.observation_space)+' is not supported. Only Box or Dict' if env.observation_space.__class__ in [Box, OldBox]: + dtype = env.observation_space.dtype + low = np.append(env.observation_space.low, 0.0) high = np.append(env.observation_space.high, 1.0) self.observation_space = Box(low, high, dtype=dtype) else: - import pdb - pdb.set_trace() - exit + spaces = copy.copy(env.observation_space.spaces) + dtype = np.float64 + spaces['time_awareness'] = Box(0, 1, dtype=dtype) + + self.observation_space = Dict(spaces) self.is_vector_env = getattr(env, "is_vector_env", False)