Fix: Make wrappers work with BB and Dict-Space
This commit is contained in:
parent
b032dec5fe
commit
9ade0dcdc4
127
fancy_gym/utils/wrappers.py
Normal file
127
fancy_gym/utils/wrappers.py
Normal file
@ -0,0 +1,127 @@
|
||||
from gymnasium.spaces import Box, Dict, flatten, flatten_space
|
||||
from gym.spaces import Box as OldBox
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import copy
|
||||
|
||||
|
||||
class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Augment the observation with the current time step in the episode.
|
||||
|
||||
The observation space of the wrapped environment is assumed to be a flat :class:`Box` or flattable :class:`Dict`.
|
||||
In particular, pixel observations are not supported. This wrapper will append the current progress within the current episode to the observation.
|
||||
The progress will be indicated as a number between 0 and 1.
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, enforce_dtype_float32=False):
|
||||
"""Initialize :class:`TimeAwareObservation` that requires an environment with a flat :class:`Box` or flattable :class:`Dict` observation space.
|
||||
|
||||
Args:
|
||||
env: The environment to apply the wrapper
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
allowed_classes = [Box, OldBox, Dict]
|
||||
if enforce_dtype_float32:
|
||||
assert env.observation_space.dtype == np.float32, 'TimeAwareObservation was given an environment with a dtype!=np.float32 ('+str(
|
||||
env.observation_space.dtype)+'). This requirement can be removed by setting enforce_dtype_float32=False.'
|
||||
assert env.observation_space.__class__ in allowed_classes, str(env.observation_space)+' is not supported. Only Box or Dict'
|
||||
|
||||
if env.observation_space.__class__ in [Box, OldBox]:
|
||||
dtype = env.observation_space.dtype
|
||||
|
||||
low = np.append(env.observation_space.low, 0.0)
|
||||
high = np.append(env.observation_space.high, 1.0)
|
||||
|
||||
self.observation_space = Box(low, high, dtype=dtype)
|
||||
else:
|
||||
spaces = copy.copy(env.observation_space.spaces)
|
||||
dtype = np.float64
|
||||
spaces['time_awareness'] = Box(0, 1, dtype=dtype)
|
||||
|
||||
self.observation_space = Dict(spaces)
|
||||
|
||||
self.is_vector_env = getattr(env, "is_vector_env", False)
|
||||
|
||||
def observation(self, observation):
|
||||
"""Adds to the observation with the current time step.
|
||||
|
||||
Args:
|
||||
observation: The observation to add the time step to
|
||||
|
||||
Returns:
|
||||
The observation with the time step appended to (relative to total number of steps)
|
||||
"""
|
||||
if self.observation_space.__class__ in [Box, OldBox]:
|
||||
return np.append(observation, self.t / self.env.spec.max_episode_steps)
|
||||
else:
|
||||
obs = copy.copy(observation)
|
||||
obs['time_awareness'] = self.t / self.env.spec.max_episode_steps
|
||||
return obs
|
||||
|
||||
def step(self, action):
|
||||
"""Steps through the environment, incrementing the time step.
|
||||
|
||||
Args:
|
||||
action: The action to take
|
||||
|
||||
Returns:
|
||||
The environment's step using the action.
|
||||
"""
|
||||
self.t += 1
|
||||
return super().step(action)
|
||||
|
||||
def reset(self, **kwargs):
|
||||
"""Reset the environment setting the time to zero.
|
||||
|
||||
Args:
|
||||
**kwargs: Kwargs to apply to env.reset()
|
||||
|
||||
Returns:
|
||||
The reset environment
|
||||
"""
|
||||
self.t = 0
|
||||
return super().reset(**kwargs)
|
||||
|
||||
|
||||
class FlattenObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Observation wrapper that flattens the observation.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.wrappers import FlattenObservation
|
||||
>>> env = gym.make("CarRacing-v2")
|
||||
>>> env.observation_space.shape
|
||||
(96, 96, 3)
|
||||
>>> env = FlattenObservation(env)
|
||||
>>> env.observation_space.shape
|
||||
(27648,)
|
||||
>>> obs, _ = env.reset()
|
||||
>>> obs.shape
|
||||
(27648,)
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env):
|
||||
"""Flattens the observations of an environment.
|
||||
|
||||
Args:
|
||||
env: The environment to apply the wrapper
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
|
||||
self.observation_space = flatten_space(env.observation_space)
|
||||
|
||||
def observation(self, observation):
|
||||
"""Flattens an observation.
|
||||
|
||||
Args:
|
||||
observation: The observation to flatten
|
||||
|
||||
Returns:
|
||||
The flattened observation
|
||||
"""
|
||||
try:
|
||||
return flatten(self.env.observation_space, observation)
|
||||
except:
|
||||
return np.array([flatten(self.env.observation_space, observation[i]) for i in range(len(observation))])
|
@ -9,7 +9,7 @@ from gymnasium.core import ActType, ObsType
|
||||
|
||||
import fancy_gym
|
||||
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||
from fancy_gym.utils.time_aware_observation import TimeAwareObservation
|
||||
from fancy_gym.utils.wrappers import TimeAwareObservation
|
||||
|
||||
SEED = 1
|
||||
ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
|
||||
@ -17,6 +17,8 @@ WRAPPERS = [fancy_gym.envs.mujoco.reacher.MPWrapper, fancy_gym.dmc.suite.ball_in
|
||||
fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper, fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper]
|
||||
ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
||||
|
||||
MAX_STEPS_FALLBACK = 500
|
||||
|
||||
|
||||
class Object(object):
|
||||
pass
|
||||
@ -115,11 +117,6 @@ def test_verbosity(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]
|
||||
@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
|
||||
@pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
|
||||
def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
|
||||
if not env.spec.max_episode_steps:
|
||||
# Not all envs expose a max_episode_steps.
|
||||
# To use those with MPs, they could be put in a time_limit-wrapper.
|
||||
return True
|
||||
|
||||
basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
|
||||
|
||||
env_id, wrapper_class = env_wrap
|
||||
@ -127,7 +124,7 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
|
||||
{'trajectory_generator_type': mp_type},
|
||||
{'controller_type': 'motor'},
|
||||
{'phase_generator_type': 'exp'},
|
||||
{'basis_generator_type': basis_generator_type})
|
||||
{'basis_generator_type': basis_generator_type}, fallback_max_steps=MAX_STEPS_FALLBACK)
|
||||
|
||||
for i in range(5):
|
||||
env.reset(seed=SEED)
|
||||
|
@ -11,7 +11,8 @@ from gymnasium import spaces
|
||||
|
||||
import fancy_gym
|
||||
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||
from fancy_gym.utils.time_aware_observation import TimeAwareObservation
|
||||
from fancy_gym.utils.wrappers import TimeAwareObservation
|
||||
from fancy_gym.utils.make_env_helpers import ensure_finite_time
|
||||
|
||||
SEED = 1
|
||||
ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
|
||||
@ -19,6 +20,8 @@ WRAPPERS = [fancy_gym.envs.mujoco.reacher.MPWrapper, fancy_gym.dmc.suite.ball_in
|
||||
fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper, fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper]
|
||||
ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
||||
|
||||
MAX_STEPS_FALLBACK = 100
|
||||
|
||||
|
||||
class ToyEnv(gym.Env):
|
||||
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
|
||||
@ -64,7 +67,7 @@ def setup():
|
||||
def test_learn_sub_trajectories(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]],
|
||||
add_time_aware_wrapper_before: bool):
|
||||
env_id, wrapper_class = env_wrap
|
||||
env_step = TimeAwareObservation(fancy_gym.make(env_id, SEED))
|
||||
env_step = TimeAwareObservation(ensure_finite_time(fancy_gym.make(env_id, SEED), MAX_STEPS_FALLBACK))
|
||||
wrappers = [wrapper_class]
|
||||
|
||||
# has time aware wrapper
|
||||
@ -75,15 +78,14 @@ def test_learn_sub_trajectories(mp_type: str, env_wrap: Tuple[str, Type[RawInter
|
||||
{'trajectory_generator_type': mp_type},
|
||||
{'controller_type': 'motor'},
|
||||
{'phase_generator_type': 'exp'},
|
||||
{'basis_generator_type': 'rbf'}, seed=SEED)
|
||||
{'basis_generator_type': 'rbf'}, seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
|
||||
|
||||
assert env.learn_sub_trajectories
|
||||
assert env.spec.max_episode_steps
|
||||
assert env_step.spec.max_episode_steps
|
||||
assert env.traj_gen.learn_tau
|
||||
# This also verifies we are not adding the TimeAwareObservationWrapper twice
|
||||
if env.observation_space.__class__ in [spaces.Dict]:
|
||||
assert spaces.flatten_space(env.observation_space) == env_step.observation_space
|
||||
else:
|
||||
assert env.observation_space == env_step.observation_space
|
||||
assert spaces.flatten_space(env_step.observation_space) == spaces.flatten_space(env.observation_space)
|
||||
|
||||
done = True
|
||||
|
||||
@ -112,7 +114,7 @@ def test_learn_sub_trajectories(mp_type: str, env_wrap: Tuple[str, Type[RawInter
|
||||
def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]],
|
||||
add_time_aware_wrapper_before: bool, replanning_time: int):
|
||||
env_id, wrapper_class = env_wrap
|
||||
env_step = TimeAwareObservation(fancy_gym.make(env_id, SEED))
|
||||
env_step = TimeAwareObservation(ensure_finite_time(fancy_gym.make(env_id, SEED), MAX_STEPS_FALLBACK))
|
||||
wrappers = [wrapper_class]
|
||||
|
||||
# has time aware wrapper
|
||||
@ -128,15 +130,14 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra
|
||||
{'trajectory_generator_type': mp_type},
|
||||
{'controller_type': 'motor'},
|
||||
{'phase_generator_type': phase_generator_type},
|
||||
{'basis_generator_type': basis_generator_type}, seed=SEED)
|
||||
{'basis_generator_type': basis_generator_type}, seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
|
||||
|
||||
assert env.do_replanning
|
||||
assert env.spec.max_episode_steps
|
||||
assert env_step.spec.max_episode_steps
|
||||
assert callable(env.replanning_schedule)
|
||||
# This also verifies we are not adding the TimeAwareObservationWrapper twice
|
||||
if env.observation_space.__class__ in [spaces.Dict]:
|
||||
assert spaces.flatten_space(env.observation_space) == env_step.observation_space
|
||||
else:
|
||||
assert env.observation_space == env_step.observation_space
|
||||
assert spaces.flatten_space(env_step.observation_space) == spaces.flatten_space(env.observation_space)
|
||||
|
||||
env.reset(seed=SEED)
|
||||
|
||||
@ -177,7 +178,7 @@ def test_max_planning_times(mp_type: str, max_planning_times: int, sub_segment_s
|
||||
},
|
||||
{'basis_generator_type': basis_generator_type,
|
||||
},
|
||||
seed=SEED)
|
||||
seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
|
||||
_ = env.reset(seed=SEED)
|
||||
done = False
|
||||
planning_times = 0
|
||||
@ -209,7 +210,7 @@ def test_replanning_with_learn_tau(mp_type: str, max_planning_times: int, sub_se
|
||||
},
|
||||
{'basis_generator_type': basis_generator_type,
|
||||
},
|
||||
seed=SEED)
|
||||
seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
|
||||
_ = env.reset(seed=SEED)
|
||||
done = False
|
||||
planning_times = 0
|
||||
@ -242,7 +243,7 @@ def test_replanning_with_learn_delay(mp_type: str, max_planning_times: int, sub_
|
||||
},
|
||||
{'basis_generator_type': basis_generator_type,
|
||||
},
|
||||
seed=SEED)
|
||||
seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
|
||||
_ = env.reset(seed=SEED)
|
||||
done = False
|
||||
planning_times = 0
|
||||
@ -297,7 +298,7 @@ def test_replanning_with_learn_delay_and_tau(mp_type: str, max_planning_times: i
|
||||
},
|
||||
{'basis_generator_type': basis_generator_type,
|
||||
},
|
||||
seed=SEED)
|
||||
seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
|
||||
_ = env.reset(seed=SEED)
|
||||
done = False
|
||||
planning_times = 0
|
||||
@ -346,7 +347,7 @@ def test_replanning_schedule(mp_type: str, max_planning_times: int, sub_segment_
|
||||
},
|
||||
{'basis_generator_type': basis_generator_type,
|
||||
},
|
||||
seed=SEED)
|
||||
seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
|
||||
_ = env.reset(seed=SEED)
|
||||
for i in range(max_planning_times):
|
||||
action = env.action_space.sample()
|
||||
|
Loading…
Reference in New Issue
Block a user