diff --git a/fancy_gym/utils/wrappers.py b/fancy_gym/utils/wrappers.py
new file mode 100644
index 0000000..03542c7
--- /dev/null
+++ b/fancy_gym/utils/wrappers.py
@@ -0,0 +1,159 @@
+from gymnasium.spaces import Box, Dict, flatten, flatten_space
+from gym.spaces import Box as OldBox
+import gymnasium as gym
+import numpy as np
+import copy
+
+
+class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
+    """Augment the observation with the agent's progress through the current episode.
+
+    The observation space of the wrapped environment is assumed to be a flat :class:`Box` or a
+    flattenable :class:`Dict`. In particular, pixel observations are not supported. This wrapper
+    appends the progress within the current episode, as a number between 0 and 1, to the observation.
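+
+    Example:
+        A minimal usage sketch (``CartPole-v1`` is only an illustrative choice; any flat
+        :class:`Box` environment whose spec defines ``max_episode_steps`` behaves the same):
+
+        >>> import gymnasium as gym
+        >>> env = TimeAwareObservation(gym.make("CartPole-v1"))
+        >>> env.observation_space.shape
+        (5,)
+        >>> obs, _ = env.reset()
+        >>> obs.shape
+        (5,)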
+    """
+
+    def __init__(self, env: gym.Env, enforce_dtype_float32=False):
+        """Initialize :class:`TimeAwareObservation`, which requires an environment with a flat
+        :class:`Box` or flattenable :class:`Dict` observation space.
+
+        Args:
+            env: The environment to apply the wrapper to
+            enforce_dtype_float32: If True, require the wrapped observation space to use np.float32
+        """
+        gym.utils.RecordConstructorArgs.__init__(self)
+        gym.ObservationWrapper.__init__(self, env)
+        allowed_classes = [Box, OldBox, Dict]
+        if enforce_dtype_float32:
+            assert env.observation_space.dtype == np.float32, \
+                f'TimeAwareObservation was given an environment with dtype!=np.float32 ' \
+                f'({env.observation_space.dtype}). This requirement can be removed by ' \
+                f'setting enforce_dtype_float32=False.'
+        assert env.observation_space.__class__ in allowed_classes, \
+            f'{env.observation_space} is not supported. Only Box or Dict observation spaces are allowed.'
+
+        if env.observation_space.__class__ in [Box, OldBox]:
+            dtype = env.observation_space.dtype
+
+            low = np.append(env.observation_space.low, 0.0)
+            high = np.append(env.observation_space.high, 1.0)
+
+            self.observation_space = Box(low, high, dtype=dtype)
+        else:
+            spaces = copy.copy(env.observation_space.spaces)
+            dtype = np.float64
+            spaces['time_awareness'] = Box(0, 1, dtype=dtype)
+
+            self.observation_space = Dict(spaces)
+
+        self.is_vector_env = getattr(env, "is_vector_env", False)
+
+    def observation(self, observation):
+        """Add the current relative time step to the observation.
+
+        Args:
+            observation: The observation to add the time step to
+
+        Returns:
+            The observation with the relative time step (current step divided by
+            the maximum number of episode steps) appended
+        """
+        if self.observation_space.__class__ in [Box, OldBox]:
+            return np.append(observation, self.t / self.env.spec.max_episode_steps)
+        else:
+            obs = copy.copy(observation)
+            obs['time_awareness'] = self.t / self.env.spec.max_episode_steps
+            return obs
+
+    def step(self, action):
+        """Step through the environment, incrementing the time step.
+
+        Args:
+            action: The action to take
+
+        Returns:
+            The result of the environment's step using the action.
+        """
+        self.t += 1
+        return super().step(action)
+
+    def reset(self, **kwargs):
+        """Reset the environment, setting the time step to zero.
+
+        Args:
+            **kwargs: Kwargs to apply to env.reset()
+
+        Returns:
+            The result of resetting the environment
+        """
+        self.t = 0
+        return super().reset(**kwargs)
+
+
+class FlattenObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
+    """Observation wrapper that flattens the observation.
+
+    Example:
+        >>> import gymnasium as gym
+        >>> from fancy_gym.utils.wrappers import FlattenObservation
+        >>> env = gym.make("CarRacing-v2")
+        >>> env.observation_space.shape
+        (96, 96, 3)
+        >>> env = FlattenObservation(env)
+        >>> env.observation_space.shape
+        (27648,)
+        >>> obs, _ = env.reset()
+        >>> obs.shape
+        (27648,)
+    """
+
+    def __init__(self, env: gym.Env):
+        """Flattens the observations of an environment.
+
+        Args:
+            env: The environment to apply the wrapper to
+        """
+        gym.utils.RecordConstructorArgs.__init__(self)
+        gym.ObservationWrapper.__init__(self, env)
+
+        self.observation_space = flatten_space(env.observation_space)
+
+    def observation(self, observation):
+        """Flattens an observation.
+
+        Args:
+            observation: The observation to flatten
+
+        Returns:
+            The flattened observation
+        """
+        try:
+            return flatten(self.env.observation_space, observation)
+        except Exception:
+            # Vectorized environments pass a batch of observations; flatten
+            # each observation in the batch individually.
+            return np.array([flatten(self.env.observation_space, observation[i]) for i in range(len(observation))])
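+
+
+if __name__ == '__main__':
+    # Minimal smoke test, not part of the public API: a sketch assuming a
+    # standard gymnasium install with CartPole-v1 available (any flat Box env
+    # whose spec sets max_episode_steps would do).
+    _env = TimeAwareObservation(gym.make('CartPole-v1'))
+    # One extra dimension holding the episode progress should be appended.
+    assert _env.observation_space.shape[0] == _env.env.observation_space.shape[0] + 1
+    _obs, _ = _env.reset()
+    _obs, _, _, _, _ = _env.step(_env.action_space.sample())
+    # After one step, the appended progress entry should be strictly positive.
+    assert 0 < _obs[-1] <= 1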
diff --git a/test/test_black_box.py b/test/test_black_box.py
index 61926cf..1492958 100644
--- a/test/test_black_box.py
+++ b/test/test_black_box.py
@@ -9,7 +9,7 @@ from gymnasium.core import ActType, ObsType
 import fancy_gym
 from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
-from fancy_gym.utils.time_aware_observation import TimeAwareObservation
+from fancy_gym.utils.wrappers import TimeAwareObservation
 
 SEED = 1
 ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
@@ -17,6 +17,8 @@ WRAPPERS = [fancy_gym.envs.mujoco.reacher.MPWrapper, fancy_gym.dmc.suite.ball_in
             fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper, fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper]
 ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
 
+MAX_STEPS_FALLBACK = 500
+
 
 class Object(object):
     pass
@@ -115,11 +117,6 @@ def test_verbosity(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]
 @pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
 @pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
 def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
-    if not env.spec.max_episode_steps:
-        # Not all envs expose a max_episode_steps.
-        # To use those with MPs, they could be put in a time_limit-wrapper.
-        return True
-
     basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
 
     env_id, wrapper_class = env_wrap
@@ -127,7 +124,7 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
                            {'trajectory_generator_type': mp_type},
                            {'controller_type': 'motor'},
                            {'phase_generator_type': 'exp'},
-                           {'basis_generator_type': basis_generator_type})
+                           {'basis_generator_type': basis_generator_type}, fallback_max_steps=MAX_STEPS_FALLBACK)
 
     for i in range(5):
         env.reset(seed=SEED)
diff --git a/test/test_replanning_sequencing.py b/test/test_replanning_sequencing.py
index 6425ef8..f219bbb 100644
--- a/test/test_replanning_sequencing.py
+++ b/test/test_replanning_sequencing.py
@@ -11,7 +11,8 @@ from gymnasium import spaces
 import fancy_gym
 from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
-from fancy_gym.utils.time_aware_observation import TimeAwareObservation
+from fancy_gym.utils.wrappers import TimeAwareObservation
+from fancy_gym.utils.make_env_helpers import ensure_finite_time
 
 SEED = 1
 ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
@@ -19,6 +20,8 @@ WRAPPERS = [fancy_gym.envs.mujoco.reacher.MPWrapper, fancy_gym.dmc.suite.ball_in
             fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper, fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper]
 ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
 
+MAX_STEPS_FALLBACK = 100
+
 
 class ToyEnv(gym.Env):
     observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
@@ -64,7 +67,7 @@ def setup():
 def test_learn_sub_trajectories(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]],
                                 add_time_aware_wrapper_before: bool):
     env_id, wrapper_class = env_wrap
-    env_step = TimeAwareObservation(fancy_gym.make(env_id, SEED))
+    env_step = TimeAwareObservation(ensure_finite_time(fancy_gym.make(env_id, SEED), MAX_STEPS_FALLBACK))
     wrappers = [wrapper_class]
 
     # has time aware wrapper
@@ -75,15 +78,14 @@ def test_learn_sub_trajectories(mp_type: str, env_wrap: Tuple[str, Type[RawInter
                            {'trajectory_generator_type': mp_type},
                            {'controller_type': 'motor'},
                            {'phase_generator_type': 'exp'},
-                           {'basis_generator_type': 'rbf'}, seed=SEED)
+                           {'basis_generator_type': 'rbf'}, seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
 
     assert env.learn_sub_trajectories
+    assert env.spec.max_episode_steps
+    assert env_step.spec.max_episode_steps
     assert env.traj_gen.learn_tau
     # This also verifies we are not adding the TimeAwareObservationWrapper twice
-    if env.observation_space.__class__ in [spaces.Dict]:
-        assert spaces.flatten_space(env.observation_space) == env_step.observation_space
-    else:
-        assert env.observation_space == env_step.observation_space
+    assert spaces.flatten_space(env_step.observation_space) == spaces.flatten_space(env.observation_space)
 
     done = True
 
@@ -112,7 +114,7 @@ def test_learn_sub_trajectories(mp_type: str, env_wrap: Tuple[str, Type[RawInter
 def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]],
                          add_time_aware_wrapper_before: bool, replanning_time: int):
     env_id, wrapper_class = env_wrap
-    env_step = TimeAwareObservation(fancy_gym.make(env_id, SEED))
+    env_step = TimeAwareObservation(ensure_finite_time(fancy_gym.make(env_id, SEED), MAX_STEPS_FALLBACK))
     wrappers = [wrapper_class]
 
     # has time aware wrapper
@@ -128,15 +130,14 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra
                            {'trajectory_generator_type': mp_type},
                            {'controller_type': 'motor'},
                            {'phase_generator_type': phase_generator_type},
-                           {'basis_generator_type': basis_generator_type}, seed=SEED)
+                           {'basis_generator_type': basis_generator_type}, seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
 
     assert env.do_replanning
+    assert env.spec.max_episode_steps
+    assert env_step.spec.max_episode_steps
     assert callable(env.replanning_schedule)
     # This also verifies we are not adding the TimeAwareObservationWrapper twice
-    if env.observation_space.__class__ in [spaces.Dict]:
-        assert spaces.flatten_space(env.observation_space) == env_step.observation_space
-    else:
-        assert env.observation_space == env_step.observation_space
+    assert spaces.flatten_space(env_step.observation_space) == spaces.flatten_space(env.observation_space)
 
     env.reset(seed=SEED)
@@ -177,7 +178,7 @@ def test_max_planning_times(mp_type: str, max_planning_times: int, sub_segment_s
                            },
                            {'basis_generator_type': basis_generator_type,
                            },
-                           seed=SEED)
+                           seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
     _ = env.reset(seed=SEED)
     done = False
     planning_times = 0
@@ -209,7 +210,7 @@ def test_replanning_with_learn_tau(mp_type: str, max_planning_times: int, sub_se
                            },
                            {'basis_generator_type': basis_generator_type,
                            },
-                           seed=SEED)
+                           seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
     _ = env.reset(seed=SEED)
     done = False
     planning_times = 0
@@ -242,7 +243,7 @@ def test_replanning_with_learn_delay(mp_type: str, max_planning_times: int, sub_
                            },
                            {'basis_generator_type': basis_generator_type,
                            },
-                           seed=SEED)
+                           seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
     _ = env.reset(seed=SEED)
     done = False
     planning_times = 0
@@ -297,7 +298,7 @@ def test_replanning_with_learn_delay_and_tau(mp_type: str, max_planning_times: i
                            },
                            {'basis_generator_type': basis_generator_type,
                            },
-                           seed=SEED)
+                           seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
     _ = env.reset(seed=SEED)
     done = False
     planning_times = 0
@@ -346,7 +347,7 @@ def test_replanning_schedule(mp_type: str, max_planning_times: int, sub_segment_
                            },
                            {'basis_generator_type': basis_generator_type,
                            },
-                           seed=SEED)
+                           seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
     _ = env.reset(seed=SEED)
     for i in range(max_planning_times):
         action = env.action_space.sample()