Fix: Make wrappers work with BB and Dict-Space
This commit is contained in:
parent
b032dec5fe
commit
9ade0dcdc4
127
fancy_gym/utils/wrappers.py
Normal file
127
fancy_gym/utils/wrappers.py
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
from gymnasium.spaces import Box, Dict, flatten, flatten_space
|
||||||
|
from gym.spaces import Box as OldBox
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
import copy
|
||||||
|
|
||||||
|
|
||||||
|
class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
    """Augment the observation with the current time step in the episode.

    The observation space of the wrapped environment is assumed to be a flat
    :class:`Box` or flattable :class:`Dict`. In particular, pixel observations
    are not supported. This wrapper will append the current progress within the
    current episode to the observation. The progress will be indicated as a
    number between 0 and 1.
    """

    def __init__(self, env: gym.Env, enforce_dtype_float32: bool = False):
        """Initialize :class:`TimeAwareObservation` that requires an environment with a flat :class:`Box` or flattable :class:`Dict` observation space.

        Args:
            env: The environment to apply the wrapper
            enforce_dtype_float32: If True, assert that the wrapped env's
                observation dtype is ``np.float32`` (useful when downstream
                code requires a uniform dtype).
        """
        gym.utils.RecordConstructorArgs.__init__(self)
        gym.ObservationWrapper.__init__(self, env)
        allowed_classes = [Box, OldBox, Dict]
        if enforce_dtype_float32:
            assert env.observation_space.dtype == np.float32, 'TimeAwareObservation was given an environment with a dtype!=np.float32 ('+str(
                env.observation_space.dtype)+'). This requirement can be removed by setting enforce_dtype_float32=False.'
        assert env.observation_space.__class__ in allowed_classes, str(env.observation_space)+' is not supported. Only Box or Dict'

        if env.observation_space.__class__ in [Box, OldBox]:
            # Flat Box: extend the space by one trailing dimension in [0, 1]
            # holding the episode progress.
            dtype = env.observation_space.dtype

            low = np.append(env.observation_space.low, 0.0)
            high = np.append(env.observation_space.high, 1.0)

            self.observation_space = Box(low, high, dtype=dtype)
            # dtype of the progress entry matches the wrapped space's dtype.
            self._time_dtype = dtype
        else:
            # Dict: add an extra 'time_awareness' sub-space instead of
            # touching the existing entries.
            spaces = copy.copy(env.observation_space.spaces)
            dtype = np.float64
            spaces['time_awareness'] = Box(0, 1, dtype=dtype)
            self._time_dtype = dtype

            self.observation_space = Dict(spaces)

        self.is_vector_env = getattr(env, "is_vector_env", False)
        # Initialize the step counter here so that observation()/step() do not
        # crash with AttributeError if they are invoked before the first
        # reset() call.
        self.t = 0

    def observation(self, observation):
        """Adds to the observation with the current time step.

        Args:
            observation: The observation to add the time step to

        Returns:
            The observation with the time step appended to (relative to total number of steps)
        """
        if self.observation_space.__class__ in [Box, OldBox]:
            return np.append(observation, self.t / self.env.spec.max_episode_steps)
        else:
            obs = copy.copy(observation)
            # Wrap the scalar progress in a shape-(1,) array of the declared
            # dtype so the result is actually contained in the declared
            # Box(0, 1) sub-space (a bare Python float has shape () and would
            # fail observation_space.contains()).
            obs['time_awareness'] = np.array(
                [self.t / self.env.spec.max_episode_steps], dtype=self._time_dtype)
            return obs

    def step(self, action):
        """Steps through the environment, incrementing the time step.

        Args:
            action: The action to take

        Returns:
            The environment's step using the action.
        """
        self.t += 1
        return super().step(action)

    def reset(self, **kwargs):
        """Reset the environment setting the time to zero.

        Args:
            **kwargs: Kwargs to apply to env.reset()

        Returns:
            The reset environment
        """
        self.t = 0
        return super().reset(**kwargs)
||||||
|
class FlattenObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
    """Observation wrapper that flattens the observation.

    Unlike :class:`gymnasium.wrappers.FlattenObservation`, this variant also
    handles batched (vector-env) observations by flattening each entry.

    Example:
        >>> import gymnasium as gym
        >>> from fancy_gym.utils.wrappers import FlattenObservation
        >>> env = gym.make("CarRacing-v2")
        >>> env.observation_space.shape
        (96, 96, 3)
        >>> env = FlattenObservation(env)
        >>> env.observation_space.shape
        (27648,)
        >>> obs, _ = env.reset()
        >>> obs.shape
        (27648,)
    """

    def __init__(self, env: gym.Env):
        """Flattens the observations of an environment.

        Args:
            env: The environment to apply the wrapper
        """
        gym.utils.RecordConstructorArgs.__init__(self)
        gym.ObservationWrapper.__init__(self, env)

        self.observation_space = flatten_space(env.observation_space)

    def observation(self, observation):
        """Flattens an observation.

        Args:
            observation: The observation to flatten

        Returns:
            The flattened observation
        """
        try:
            return flatten(self.env.observation_space, observation)
        except Exception:
            # flatten() rejects batched observations from vector envs; fall
            # back to flattening every entry individually. Catch Exception
            # (not a bare except) so KeyboardInterrupt/SystemExit propagate.
            return np.array([flatten(self.env.observation_space, observation[i]) for i in range(len(observation))])
|
@ -9,7 +9,7 @@ from gymnasium.core import ActType, ObsType
|
|||||||
|
|
||||||
import fancy_gym
|
import fancy_gym
|
||||||
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||||
from fancy_gym.utils.time_aware_observation import TimeAwareObservation
|
from fancy_gym.utils.wrappers import TimeAwareObservation
|
||||||
|
|
||||||
SEED = 1
|
SEED = 1
|
||||||
ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
|
ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
|
||||||
@ -17,6 +17,8 @@ WRAPPERS = [fancy_gym.envs.mujoco.reacher.MPWrapper, fancy_gym.dmc.suite.ball_in
|
|||||||
fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper, fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper]
|
fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper, fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper]
|
||||||
ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
||||||
|
|
||||||
|
MAX_STEPS_FALLBACK = 500
|
||||||
|
|
||||||
|
|
||||||
class Object(object):
|
class Object(object):
|
||||||
pass
|
pass
|
||||||
@ -115,11 +117,6 @@ def test_verbosity(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]
|
|||||||
@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
|
@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
|
||||||
@pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
|
@pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
|
||||||
def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
|
def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
|
||||||
if not env.spec.max_episode_steps:
|
|
||||||
# Not all envs expose a max_episode_steps.
|
|
||||||
# To use those with MPs, they could be put in a time_limit-wrapper.
|
|
||||||
return True
|
|
||||||
|
|
||||||
basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
|
basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
|
||||||
|
|
||||||
env_id, wrapper_class = env_wrap
|
env_id, wrapper_class = env_wrap
|
||||||
@ -127,7 +124,7 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
|
|||||||
{'trajectory_generator_type': mp_type},
|
{'trajectory_generator_type': mp_type},
|
||||||
{'controller_type': 'motor'},
|
{'controller_type': 'motor'},
|
||||||
{'phase_generator_type': 'exp'},
|
{'phase_generator_type': 'exp'},
|
||||||
{'basis_generator_type': basis_generator_type})
|
{'basis_generator_type': basis_generator_type}, fallback_max_steps=MAX_STEPS_FALLBACK)
|
||||||
|
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
env.reset(seed=SEED)
|
env.reset(seed=SEED)
|
||||||
|
@ -11,7 +11,8 @@ from gymnasium import spaces
|
|||||||
|
|
||||||
import fancy_gym
|
import fancy_gym
|
||||||
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||||
from fancy_gym.utils.time_aware_observation import TimeAwareObservation
|
from fancy_gym.utils.wrappers import TimeAwareObservation
|
||||||
|
from fancy_gym.utils.make_env_helpers import ensure_finite_time
|
||||||
|
|
||||||
SEED = 1
|
SEED = 1
|
||||||
ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
|
ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
|
||||||
@ -19,6 +20,8 @@ WRAPPERS = [fancy_gym.envs.mujoco.reacher.MPWrapper, fancy_gym.dmc.suite.ball_in
|
|||||||
fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper, fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper]
|
fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper, fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper]
|
||||||
ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
||||||
|
|
||||||
|
MAX_STEPS_FALLBACK = 100
|
||||||
|
|
||||||
|
|
||||||
class ToyEnv(gym.Env):
|
class ToyEnv(gym.Env):
|
||||||
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
|
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
|
||||||
@ -64,7 +67,7 @@ def setup():
|
|||||||
def test_learn_sub_trajectories(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]],
|
def test_learn_sub_trajectories(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]],
|
||||||
add_time_aware_wrapper_before: bool):
|
add_time_aware_wrapper_before: bool):
|
||||||
env_id, wrapper_class = env_wrap
|
env_id, wrapper_class = env_wrap
|
||||||
env_step = TimeAwareObservation(fancy_gym.make(env_id, SEED))
|
env_step = TimeAwareObservation(ensure_finite_time(fancy_gym.make(env_id, SEED), MAX_STEPS_FALLBACK))
|
||||||
wrappers = [wrapper_class]
|
wrappers = [wrapper_class]
|
||||||
|
|
||||||
# has time aware wrapper
|
# has time aware wrapper
|
||||||
@ -75,15 +78,14 @@ def test_learn_sub_trajectories(mp_type: str, env_wrap: Tuple[str, Type[RawInter
|
|||||||
{'trajectory_generator_type': mp_type},
|
{'trajectory_generator_type': mp_type},
|
||||||
{'controller_type': 'motor'},
|
{'controller_type': 'motor'},
|
||||||
{'phase_generator_type': 'exp'},
|
{'phase_generator_type': 'exp'},
|
||||||
{'basis_generator_type': 'rbf'}, seed=SEED)
|
{'basis_generator_type': 'rbf'}, seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
|
||||||
|
|
||||||
assert env.learn_sub_trajectories
|
assert env.learn_sub_trajectories
|
||||||
|
assert env.spec.max_episode_steps
|
||||||
|
assert env_step.spec.max_episode_steps
|
||||||
assert env.traj_gen.learn_tau
|
assert env.traj_gen.learn_tau
|
||||||
# This also verifies we are not adding the TimeAwareObservationWrapper twice
|
# This also verifies we are not adding the TimeAwareObservationWrapper twice
|
||||||
if env.observation_space.__class__ in [spaces.Dict]:
|
assert spaces.flatten_space(env_step.observation_space) == spaces.flatten_space(env.observation_space)
|
||||||
assert spaces.flatten_space(env.observation_space) == env_step.observation_space
|
|
||||||
else:
|
|
||||||
assert env.observation_space == env_step.observation_space
|
|
||||||
|
|
||||||
done = True
|
done = True
|
||||||
|
|
||||||
@ -112,7 +114,7 @@ def test_learn_sub_trajectories(mp_type: str, env_wrap: Tuple[str, Type[RawInter
|
|||||||
def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]],
|
def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]],
|
||||||
add_time_aware_wrapper_before: bool, replanning_time: int):
|
add_time_aware_wrapper_before: bool, replanning_time: int):
|
||||||
env_id, wrapper_class = env_wrap
|
env_id, wrapper_class = env_wrap
|
||||||
env_step = TimeAwareObservation(fancy_gym.make(env_id, SEED))
|
env_step = TimeAwareObservation(ensure_finite_time(fancy_gym.make(env_id, SEED), MAX_STEPS_FALLBACK))
|
||||||
wrappers = [wrapper_class]
|
wrappers = [wrapper_class]
|
||||||
|
|
||||||
# has time aware wrapper
|
# has time aware wrapper
|
||||||
@ -128,15 +130,14 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra
|
|||||||
{'trajectory_generator_type': mp_type},
|
{'trajectory_generator_type': mp_type},
|
||||||
{'controller_type': 'motor'},
|
{'controller_type': 'motor'},
|
||||||
{'phase_generator_type': phase_generator_type},
|
{'phase_generator_type': phase_generator_type},
|
||||||
{'basis_generator_type': basis_generator_type}, seed=SEED)
|
{'basis_generator_type': basis_generator_type}, seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
|
||||||
|
|
||||||
assert env.do_replanning
|
assert env.do_replanning
|
||||||
|
assert env.spec.max_episode_steps
|
||||||
|
assert env_step.spec.max_episode_steps
|
||||||
assert callable(env.replanning_schedule)
|
assert callable(env.replanning_schedule)
|
||||||
# This also verifies we are not adding the TimeAwareObservationWrapper twice
|
# This also verifies we are not adding the TimeAwareObservationWrapper twice
|
||||||
if env.observation_space.__class__ in [spaces.Dict]:
|
assert spaces.flatten_space(env_step.observation_space) == spaces.flatten_space(env.observation_space)
|
||||||
assert spaces.flatten_space(env.observation_space) == env_step.observation_space
|
|
||||||
else:
|
|
||||||
assert env.observation_space == env_step.observation_space
|
|
||||||
|
|
||||||
env.reset(seed=SEED)
|
env.reset(seed=SEED)
|
||||||
|
|
||||||
@ -177,7 +178,7 @@ def test_max_planning_times(mp_type: str, max_planning_times: int, sub_segment_s
|
|||||||
},
|
},
|
||||||
{'basis_generator_type': basis_generator_type,
|
{'basis_generator_type': basis_generator_type,
|
||||||
},
|
},
|
||||||
seed=SEED)
|
seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
|
||||||
_ = env.reset(seed=SEED)
|
_ = env.reset(seed=SEED)
|
||||||
done = False
|
done = False
|
||||||
planning_times = 0
|
planning_times = 0
|
||||||
@ -209,7 +210,7 @@ def test_replanning_with_learn_tau(mp_type: str, max_planning_times: int, sub_se
|
|||||||
},
|
},
|
||||||
{'basis_generator_type': basis_generator_type,
|
{'basis_generator_type': basis_generator_type,
|
||||||
},
|
},
|
||||||
seed=SEED)
|
seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
|
||||||
_ = env.reset(seed=SEED)
|
_ = env.reset(seed=SEED)
|
||||||
done = False
|
done = False
|
||||||
planning_times = 0
|
planning_times = 0
|
||||||
@ -242,7 +243,7 @@ def test_replanning_with_learn_delay(mp_type: str, max_planning_times: int, sub_
|
|||||||
},
|
},
|
||||||
{'basis_generator_type': basis_generator_type,
|
{'basis_generator_type': basis_generator_type,
|
||||||
},
|
},
|
||||||
seed=SEED)
|
seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
|
||||||
_ = env.reset(seed=SEED)
|
_ = env.reset(seed=SEED)
|
||||||
done = False
|
done = False
|
||||||
planning_times = 0
|
planning_times = 0
|
||||||
@ -297,7 +298,7 @@ def test_replanning_with_learn_delay_and_tau(mp_type: str, max_planning_times: i
|
|||||||
},
|
},
|
||||||
{'basis_generator_type': basis_generator_type,
|
{'basis_generator_type': basis_generator_type,
|
||||||
},
|
},
|
||||||
seed=SEED)
|
seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
|
||||||
_ = env.reset(seed=SEED)
|
_ = env.reset(seed=SEED)
|
||||||
done = False
|
done = False
|
||||||
planning_times = 0
|
planning_times = 0
|
||||||
@ -346,7 +347,7 @@ def test_replanning_schedule(mp_type: str, max_planning_times: int, sub_segment_
|
|||||||
},
|
},
|
||||||
{'basis_generator_type': basis_generator_type,
|
{'basis_generator_type': basis_generator_type,
|
||||||
},
|
},
|
||||||
seed=SEED)
|
seed=SEED, fallback_max_steps=MAX_STEPS_FALLBACK)
|
||||||
_ = env.reset(seed=SEED)
|
_ = env.reset(seed=SEED)
|
||||||
for i in range(max_planning_times):
|
for i in range(max_planning_times):
|
||||||
action = env.action_space.sample()
|
action = env.action_space.sample()
|
||||||
|
Loading…
Reference in New Issue
Block a user