Fix: Minor bugs in time aware obs wrapper

This commit is contained in:
Dominik Moritz Roth 2023-06-11 13:47:38 +02:00
parent e44b0ed9ed
commit 2ad42f4132

View File

@ -28,10 +28,10 @@ class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorAr
assert env.observation_space.__class__ in allowed_classes, str(env.observation_space)+' is not supported. Only Box or Dict' assert env.observation_space.__class__ in allowed_classes, str(env.observation_space)+' is not supported. Only Box or Dict'
if env.observation_space.__class__ in [Box, OldBox]:
low = np.append(env.observation_space.low, 0.0) low = np.append(env.observation_space.low, 0.0)
high = np.append(env.observation_space.high, 1.0) high = np.append(env.observation_space.high, 1.0)
if env.observation_space.__class__ in [Box, OldBox]:
self.observation_space = Box(low, high, dtype=dtype) self.observation_space = Box(low, high, dtype=dtype)
else: else:
import pdb import pdb
@ -49,7 +49,7 @@ class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorAr
Returns: Returns:
The observation with the time step appended to (relative to total number of steps) The observation with the time step appended to (relative to total number of steps)
""" """
return np.append(observation, self.t / getattr(self.env, '_max_episode_steps')) return np.append(observation, self.t / self.env.spec.max_episode_steps)
def step(self, action): def step(self, action):
"""Steps through the environment, incrementing the time step. """Steps through the environment, incrementing the time step.