Implement support for Dict spaces for time_aware_observation-wrapper

This commit is contained in:
Dominik Moritz Roth 2023-06-11 17:38:16 +02:00
parent abeb963b4e
commit a23b44752e

View File

@ -2,6 +2,7 @@ from gymnasium.spaces import Box, Dict
from gym.spaces import Box as OldBox from gym.spaces import Box as OldBox
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
import copy
class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
@ -24,19 +25,21 @@ class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorAr
if enforce_dtype_float32: if enforce_dtype_float32:
assert env.observation_space.dtype == np.float32, 'TimeAwareObservation was given an environment with a dtype!=np.float32 ('+str( assert env.observation_space.dtype == np.float32, 'TimeAwareObservation was given an environment with a dtype!=np.float32 ('+str(
env.observation_space.dtype)+'). This requirement can be removed by setting enforce_dtype_float32=False.' env.observation_space.dtype)+'). This requirement can be removed by setting enforce_dtype_float32=False.'
dtype = env.observation_space.dtype
assert env.observation_space.__class__ in allowed_classes, str(env.observation_space)+' is not supported. Only Box or Dict' assert env.observation_space.__class__ in allowed_classes, str(env.observation_space)+' is not supported. Only Box or Dict'
if env.observation_space.__class__ in [Box, OldBox]: if env.observation_space.__class__ in [Box, OldBox]:
dtype = env.observation_space.dtype
low = np.append(env.observation_space.low, 0.0) low = np.append(env.observation_space.low, 0.0)
high = np.append(env.observation_space.high, 1.0) high = np.append(env.observation_space.high, 1.0)
self.observation_space = Box(low, high, dtype=dtype) self.observation_space = Box(low, high, dtype=dtype)
else: else:
import pdb spaces = copy.copy(env.observation_space.spaces)
pdb.set_trace() dtype = np.float64
exit spaces['time_awareness'] = Box(0, 1, dtype=dtype)
self.observation_space = Dict(spaces)
self.is_vector_env = getattr(env, "is_vector_env", False) self.is_vector_env = getattr(env, "is_vector_env", False)