Fix: Make wrappers work with BB and Dict-Space

Dominik Moritz Roth 2023-06-18 14:25:20 +02:00
parent b032dec5fe
commit 9ade0dcdc4
3 changed files with 150 additions and 25 deletions

fancy_gym/utils/wrappers.py Normal file

@@ -0,0 +1,127 @@
from gymnasium.spaces import Box, Dict, flatten, flatten_space
from gym.spaces import Box as OldBox
import gymnasium as gym
import numpy as np
import copy

class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
    """Augment the observation with the current time step in the episode.

    The observation space of the wrapped environment is assumed to be a flat :class:`Box` or a flattenable
    :class:`Dict`. In particular, pixel observations are not supported. This wrapper appends the agent's progress
    within the current episode to the observation. The progress is given as a number between 0 and 1.
    """

    def __init__(self, env: gym.Env, enforce_dtype_float32=False):
        """Initialize :class:`TimeAwareObservation`, which requires an environment with a flat :class:`Box` or
        flattenable :class:`Dict` observation space.

        Args:
            env: The environment to apply the wrapper
        """
        gym.utils.RecordConstructorArgs.__init__(self)
        gym.ObservationWrapper.__init__(self, env)
        allowed_classes = [Box, OldBox, Dict]
        if enforce_dtype_float32:
            assert env.observation_space.dtype == np.float32, \
                'TimeAwareObservation was given an environment with dtype != np.float32 (' \
                + str(env.observation_space.dtype) \
                + '). This requirement can be removed by setting enforce_dtype_float32=False.'
        assert env.observation_space.__class__ in allowed_classes, \
            str(env.observation_space) + ' is not supported. Only Box and Dict observation spaces are.'

        if env.observation_space.__class__ in [Box, OldBox]:
            dtype = env.observation_space.dtype
            low = np.append(env.observation_space.low, 0.0)
            high = np.append(env.observation_space.high, 1.0)
            self.observation_space = Box(low, high, dtype=dtype)
        else:
            spaces = copy.copy(env.observation_space.spaces)
            dtype = np.float64
            spaces['time_awareness'] = Box(0, 1, dtype=dtype)
            self.observation_space = Dict(spaces)

        self.is_vector_env = getattr(env, "is_vector_env", False)
    def observation(self, observation):
        """Append the current (relative) time step to the observation.

        Args:
            observation: The observation to add the time step to

        Returns:
            The observation with the time step appended (relative to the total number of steps)
        """
        if self.observation_space.__class__ in [Box, OldBox]:
            return np.append(observation, self.t / self.env.spec.max_episode_steps)
        else:
            obs = copy.copy(observation)
            obs['time_awareness'] = self.t / self.env.spec.max_episode_steps
            return obs
    def step(self, action):
        """Step through the environment, incrementing the time step.

        Args:
            action: The action to take

        Returns:
            The environment's step using the action.
        """
        self.t += 1
        return super().step(action)

    def reset(self, **kwargs):
        """Reset the environment, setting the time step to zero.

        Args:
            **kwargs: Kwargs to apply to env.reset()

        Returns:
            The reset environment
        """
        self.t = 0
        return super().reset(**kwargs)

class FlattenObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
    """Observation wrapper that flattens the observation.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import FlattenObservation
        >>> env = gym.make("CarRacing-v2")
        >>> env.observation_space.shape
        (96, 96, 3)
        >>> env = FlattenObservation(env)
        >>> env.observation_space.shape
        (27648,)
        >>> obs, _ = env.reset()
        >>> obs.shape
        (27648,)
    """

    def __init__(self, env: gym.Env):
        """Flattens the observations of an environment.

        Args:
            env: The environment to apply the wrapper
        """
        gym.utils.RecordConstructorArgs.__init__(self)
        gym.ObservationWrapper.__init__(self, env)
        self.observation_space = flatten_space(env.observation_space)
    def observation(self, observation):
        """Flattens an observation.

        Args:
            observation: The observation to flatten

        Returns:
            The flattened observation
        """
        try:
            return flatten(self.env.observation_space, observation)
        except Exception:
            # Batched observations from vectorized envs cannot be flattened in
            # one call; fall back to flattening each sub-observation separately.
            return np.array([flatten(self.env.observation_space, observation[i]) for i in range(len(observation))])
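
A minimal usage sketch (not part of the commit) of the two wrappers on a Dict observation space; ToyDictEnv and its registered id are hypothetical:

import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Dict
from fancy_gym.utils.wrappers import TimeAwareObservation, FlattenObservation

class ToyDictEnv(gym.Env):
    observation_space = Dict({'pos': Box(-1, 1, shape=(3,), dtype=np.float64)})
    action_space = Box(-1, 1, shape=(1,), dtype=np.float64)

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        return self.observation_space.sample(), 0.0, False, False, {}

# A finite horizon is required, since the wrapper reports progress in [0, 1].
gym.register(id='ToyDict-v0', entry_point=ToyDictEnv, max_episode_steps=100)
env = FlattenObservation(TimeAwareObservation(gym.make('ToyDict-v0')))
obs, _ = env.reset()  # flat vector: 3 'pos' entries plus 1 'time_awareness' entry
assert obs.shape == (4,)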


@@ -9,7 +9,7 @@ from gymnasium.core import ActType, ObsType
 import fancy_gym
 from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
-from fancy_gym.utils.time_aware_observation import TimeAwareObservation
+from fancy_gym.utils.wrappers import TimeAwareObservation
 
 SEED = 1
 ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
@@ -17,6 +17,8 @@ WRAPPERS = [fancy_gym.envs.mujoco.reacher.MPWrapper, fancy_gym.dmc.suite.ball_in
             fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper, fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper]
 ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
+MAX_STEPS_FALLBACK = 500
 
 class Object(object):
     pass
@@ -115,11 +117,6 @@ def test_verbosity(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]
 @pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
 @pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
 def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
-    if not env.spec.max_episode_steps:
-        # Not all envs expose a max_episode_steps.
-        # To use those with MPs, they could be put in a time_limit-wrapper.
-        return True
     basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
 
     env_id, wrapper_class = env_wrap
@@ -127,7 +124,7 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
                        {'trajectory_generator_type': mp_type},
                        {'controller_type': 'motor'},
                        {'phase_generator_type': 'exp'},
-                       {'basis_generator_type': basis_generator_type})
+                       {'basis_generator_type': basis_generator_type}, fallback_max_steps=MAX_STEPS_FALLBACK)
 
     for i in range(5):
         env.reset(seed=SEED)
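
The removed early-return above skipped envs that expose no max_episode_steps; the new fallback_max_steps keyword instead presumably imposes a finite horizon on them. A hedged sketch of what such a fallback could look like (the helper name apply_fallback_max_steps is made up; the actual implementation is not part of this diff):

from gymnasium.wrappers import TimeLimit

def apply_fallback_max_steps(env, fallback_max_steps=500):
    # Only impose the fallback when the env declares no finite horizon.
    has_limit = env.spec is not None and env.spec.max_episode_steps is not None
    return env if has_limit else TimeLimit(env, max_episode_steps=fallback_max_steps)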


@@ -11,7 +11,8 @@ from gymnasium import spaces
 import fancy_gym
 from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
-from fancy_gym.utils.time_aware_observation import TimeAwareObservation
+from fancy_gym.utils.wrappers import TimeAwareObservation
+from fancy_gym.utils.make_env_helpers import ensure_finite_time
 
 SEED = 1
 ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
@@ -19,6 +20,8 @@ WRAPPERS = [fancy_gym.envs.mujoco.reacher.MPWrapper, fancy_gym.dmc.suite.ball_in
             fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper, fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper]
 ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
+MAX_STEPS_FALLBACK = 100
 
 class ToyEnv(gym.Env):
     observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
@@ -64,7 +67,7 @@ def setup():
 def test_learn_sub_trajectories(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]],
                                 add_time_aware_wrapper_before: bool):
     env_id, wrapper_class = env_wrap
-    env_step = TimeAwareObservation(fancy_gym.make(env_id, SEED))
+    env_step = TimeAwareObservation(ensure_finite_time(fancy_gym.make(env_id, SEED), MAX_STEPS_FALLBACK))
     wrappers = [wrapper_class]
 
     # has time aware wrapper
@@ -75,15 +78,14 @@ def test_learn_sub_trajectories(mp_type: str, env_wrap: Tuple[str, Type[RawInter
                        {'trajectory_generator_type': mp_type},
                        {'controller_type': 'motor'},
                        {'phase_generator_type': 'exp'},
-                       {'basis_generator_type': 'rbf'}, seed=SEED)
+                       {'basis_generator_type': 'rbf'}, seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
 
     assert env.learn_sub_trajectories
     assert env.spec.max_episode_steps
+    assert env_step.spec.max_episode_steps
     assert env.traj_gen.learn_tau
     # This also verifies we are not adding the TimeAwareObservationWrapper twice
-    if env.observation_space.__class__ in [spaces.Dict]:
-        assert spaces.flatten_space(env.observation_space) == env_step.observation_space
-    else:
-        assert env.observation_space == env_step.observation_space
+    assert spaces.flatten_space(env_step.observation_space) == spaces.flatten_space(env.observation_space)
 
     done = True
@@ -112,7 +114,7 @@ def test_learn_sub_trajectories(mp_type: str, env_wrap: Tuple[str, Type[RawInter
 def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]],
                          add_time_aware_wrapper_before: bool, replanning_time: int):
     env_id, wrapper_class = env_wrap
-    env_step = TimeAwareObservation(fancy_gym.make(env_id, SEED))
+    env_step = TimeAwareObservation(ensure_finite_time(fancy_gym.make(env_id, SEED), MAX_STEPS_FALLBACK))
     wrappers = [wrapper_class]
 
     # has time aware wrapper
@@ -128,15 +130,14 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra
                        {'trajectory_generator_type': mp_type},
                        {'controller_type': 'motor'},
                        {'phase_generator_type': phase_generator_type},
-                       {'basis_generator_type': basis_generator_type}, seed=SEED)
+                       {'basis_generator_type': basis_generator_type}, seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
 
     assert env.do_replanning
     assert env.spec.max_episode_steps
+    assert env_step.spec.max_episode_steps
     assert callable(env.replanning_schedule)
     # This also verifies we are not adding the TimeAwareObservationWrapper twice
-    if env.observation_space.__class__ in [spaces.Dict]:
-        assert spaces.flatten_space(env.observation_space) == env_step.observation_space
-    else:
-        assert env.observation_space == env_step.observation_space
+    assert spaces.flatten_space(env_step.observation_space) == spaces.flatten_space(env.observation_space)
 
     env.reset(seed=SEED)
@@ -177,7 +178,7 @@ def test_max_planning_times(mp_type: str, max_planning_times: int, sub_segment_s
                         },
                        {'basis_generator_type': basis_generator_type,
                         },
-                       seed=SEED)
+                       seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
     _ = env.reset(seed=SEED)
     done = False
     planning_times = 0
@@ -209,7 +210,7 @@ def test_replanning_with_learn_tau(mp_type: str, max_planning_times: int, sub_se
                         },
                        {'basis_generator_type': basis_generator_type,
                         },
-                       seed=SEED)
+                       seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
     _ = env.reset(seed=SEED)
     done = False
     planning_times = 0
@@ -242,7 +243,7 @@ def test_replanning_with_learn_delay(mp_type: str, max_planning_times: int, sub_
                         },
                        {'basis_generator_type': basis_generator_type,
                         },
-                       seed=SEED)
+                       seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
     _ = env.reset(seed=SEED)
     done = False
     planning_times = 0
@@ -297,7 +298,7 @@ def test_replanning_with_learn_delay_and_tau(mp_type: str, max_planning_times: i
                         },
                        {'basis_generator_type': basis_generator_type,
                         },
-                       seed=SEED)
+                       seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
     _ = env.reset(seed=SEED)
     done = False
     planning_times = 0
@@ -346,7 +347,7 @@ def test_replanning_schedule(mp_type: str, max_planning_times: int, sub_segment_
                         },
                        {'basis_generator_type': basis_generator_type,
                         },
-                       seed=SEED)
+                       seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
     _ = env.reset(seed=SEED)
     for i in range(max_planning_times):
         action = env.action_space.sample()
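
A hedged usage sketch mirroring the updated tests (ensure_finite_time comes from this commit's fancy_gym.utils.make_env_helpers; its exact semantics are assumed here to be "impose a time limit if none exists"):

import fancy_gym
from fancy_gym.utils.wrappers import TimeAwareObservation
from fancy_gym.utils.make_env_helpers import ensure_finite_time

# Guarantee a finite horizon before adding time awareness, as in the tests above.
env_step = TimeAwareObservation(ensure_finite_time(fancy_gym.make('Reacher5d-v0', 1), 100))
obs, info = env_step.reset()  # for a Box space, obs[-1] holds the relative time step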