timeaware wrapper

This commit is contained in:
Fabian 2022-07-12 10:06:38 +02:00
parent 123915e4fa
commit 8dba7f199b
3 changed files with 121 additions and 59 deletions

View File

@ -0,0 +1,78 @@
"""
Adapted from: https://github.com/openai/gym/blob/907b1b20dd9ac0cba5803225059b9c6673702467/gym/wrappers/time_aware_observation.py
License: MIT
Copyright (c) 2016 OpenAI (https://openai.com)
Wrapper for adding time aware observations to environment observation.
"""
import numpy as np
import gym
from gym.spaces import Box
class TimeAwareObservation(gym.ObservationWrapper):
    """Augment the observation with the current time step in the episode.

    The observation space of the wrapped environment is assumed to be a flat
    :class:`Box`. In particular, pixel observations are not supported. This
    wrapper appends the current timestep within the current episode to the
    observation. The augmented observation is cast to the dtype of the
    wrapper's observation space so that it matches the declared space.

    Example:
        >>> import gym
        >>> env = gym.make('CartPole-v1')
        >>> env = TimeAwareObservation(env)
        >>> float(env.reset()[-1])
        0.0
        >>> float(env.step(env.action_space.sample())[0][-1])
        1.0
    """

    def __init__(self, env: gym.Env):
        """Initialize :class:`TimeAwareObservation`.

        Requires an environment with a flat :class:`Box` observation space.

        Args:
            env: The environment to apply the wrapper to.

        Raises:
            AssertionError: If the observation space of ``env`` is not a
                :class:`Box`.
        """
        super().__init__(env)
        assert isinstance(env.observation_space, Box), (
            "TimeAwareObservation requires a Box observation space, "
            f"got {type(env.observation_space).__name__}"
        )
        # Extend the bounds by one entry for the time step, which starts at 0
        # and has no upper bound. np.append flattens its inputs, hence the
        # requirement of a flat Box.
        low = np.append(self.observation_space.low, 0.0)
        high = np.append(self.observation_space.high, np.inf)
        self.observation_space = Box(low, high, dtype=self.observation_space.dtype)
        # Number of steps taken in the current episode.
        self.t = 0

    def observation(self, observation):
        """Append the current time step to the observation.

        Args:
            observation: The observation to add the time step to.

        Returns:
            The flattened observation with the time step appended, using the
            dtype of ``self.observation_space``.
        """
        augmented = np.append(observation, self.t)
        # np.append promotes e.g. float32 observations to float64 when the
        # integer time step is appended; cast back so the result has the dtype
        # declared by the observation space.
        return augmented.astype(self.observation_space.dtype, copy=False)

    def step(self, action):
        """Step through the environment, incrementing the time step.

        Args:
            action: The action to take.

        Returns:
            The environment's step result for ``action``, with the observation
            augmented by the updated time step.
        """
        # Increment first so the returned observation carries the new step count.
        self.t += 1
        return super().step(action)

    def reset(self, **kwargs):
        """Reset the environment, setting the time step to zero.

        Args:
            **kwargs: Keyword arguments forwarded to ``env.reset()``.

        Returns:
            The result of resetting the wrapped environment, with the
            observation augmented by a time step of zero.
        """
        # Zero the counter before resetting so the initial observation gets t == 0.
        self.t = 0
        return super().reset(**kwargs)

View File

@ -1,13 +1,11 @@
import numpy as np
import alr_envs import alr_envs
def example_mp(env_name="alr_envs:HoleReacherDMP-v1", seed=1, iterations=1, render=True): def example_mp(env_name="HoleReacherProMP-v0", seed=1, iterations=1, render=True):
""" """
Example for running a motion primitive based environment, which is already registered Example for running a black box based environment, which is already registered
Args: Args:
env_name: DMP env_id env_name: Black box env_id
seed: seed for deterministic behaviour seed: seed for deterministic behaviour
iterations: Number of rollout steps to run iterations: Number of rollout steps to run
render: Render the episode render: Render the episode
@ -15,8 +13,8 @@ def example_mp(env_name="alr_envs:HoleReacherDMP-v1", seed=1, iterations=1, rend
Returns: Returns:
""" """
# While in this case gym.make() is possible to use as well, we recommend our custom make env function. # Equivalent to gym, we have make function which can be used to create environments.
# First, it already takes care of seeding and second enables the use of DMC tasks within the gym interface. # It takes care of seeding and enables the use of a variety of external environments using the gym interface.
env = alr_envs.make(env_name, seed) env = alr_envs.make(env_name, seed)
rewards = 0 rewards = 0
@ -37,8 +35,12 @@ def example_mp(env_name="alr_envs:HoleReacherDMP-v1", seed=1, iterations=1, rend
else: else:
env.render(mode=None) env.render(mode=None)
# Now the action space is not the raw action but the parametrization of the trajectory generator,
# such as a ProMP
ac = env.action_space.sample() ac = env.action_space.sample()
# This executes a full trajectory
obs, reward, done, info = env.step(ac) obs, reward, done, info = env.step(ac)
# Aggregated reward
rewards += reward rewards += reward
if done: if done:

View File

@ -7,8 +7,7 @@ from typing import Iterable, Type, Union
import gym import gym
import numpy as np import numpy as np
from gym.envs.registration import register, registry
import alr_envs
try: try:
from dm_control import suite, manipulation, composer from dm_control import suite, manipulation, composer
@ -22,20 +21,16 @@ except Exception:
# catch Exception due to Mujoco-py # catch Exception due to Mujoco-py
pass pass
from gym.envs.registration import registry import alr_envs
from gym.envs.registration import register
from gym.wrappers import TimeAwareObservation
from alr_envs.black_box.black_box_wrapper import BlackBoxWrapper from alr_envs.black_box.black_box_wrapper import BlackBoxWrapper
from alr_envs.black_box.factory.basis_generator_factory import get_basis_generator from alr_envs.black_box.factory.basis_generator_factory import get_basis_generator
from alr_envs.black_box.factory.controller_factory import get_controller from alr_envs.black_box.factory.controller_factory import get_controller
from alr_envs.black_box.factory.phase_generator_factory import get_phase_generator from alr_envs.black_box.factory.phase_generator_factory import get_phase_generator
from alr_envs.black_box.factory.trajectory_generator_factory import get_trajectory_generator from alr_envs.black_box.factory.trajectory_generator_factory import get_trajectory_generator
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
from alr_envs.black_box.time_aware_observation import TimeAwareObservation
from alr_envs.utils.utils import nested_update from alr_envs.utils.utils import nested_update
ALL_FRAMEWORK_TYPES = ['meta', 'dmc', 'gym']
def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs): def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs):
""" """
@ -63,34 +58,21 @@ def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwa
return f if return_callable else f() return f if return_callable else f()
def make(env_id, seed, **kwargs): def make(env_id: str, seed: int, **kwargs):
return _make(env_id, seed, **kwargs)
def _make(env_id: str, seed, **kwargs):
""" """
Converts an env_id to an environment with the gym API. Converts an env_id to an environment with the gym API.
This also works for DeepMind Control Suite interface_wrappers This also works for DeepMind Control Suite environments that are wrapped using the DMCWrapper, they can be
for which domain name and task name are expected to be separated by "-". specified with "dmc:domain_name-task_name"
Analogously, metaworld tasks can be created as "metaworld:env_id-v2".
Args: Args:
env_id: gym name or env_id of the form "domain_name-task_name" for DMC tasks env_id: spec or env_id for gym tasks, external environments require a domain specification
**kwargs: Additional kwargs for the constructor such as pixel observations, etc. **kwargs: Additional kwargs for the constructor such as pixel observations, etc.
Returns: Gym environment Returns: Gym environment
""" """
# 'dmc:domain-task'
# 'gym:name-vX'
# 'meta:name-vX'
# 'meta:bb:name-vX'
# 'hand:name-vX'
# 'name-vX'
# 'bb:name-vX'
#
# env_id.split(':')
# if 'dmc' :
if ':' in env_id: if ':' in env_id:
split_id = env_id.split(':') split_id = env_id.split(':')
framework, env_id = split_id[-2:] framework, env_id = split_id[-2:]
@ -98,13 +80,17 @@ def _make(env_id: str, seed, **kwargs):
framework = None framework = None
if framework == 'metaworld': if framework == 'metaworld':
# MetaWorld env # MetaWorld environment
env = make_metaworld(env_id, seed=seed, **kwargs) env = make_metaworld(env_id, seed, **kwargs)
elif framework == 'dmc': elif framework == 'dmc':
# DeepMind Controlp # DeepMind Control environment
env = make_dmc(env_id, seed=seed, **kwargs) env = make_dmc(env_id, seed, **kwargs)
else: else:
env = make_gym(env_id, seed=seed, **kwargs) env = make_gym(env_id, seed, **kwargs)
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env return env
@ -285,7 +271,7 @@ def make_dmc(
) )
env = gym.make(gym_id) env = gym.make(gym_id)
env.seed(seed=seed) env.seed(seed)
return env return env
@ -300,15 +286,6 @@ def make_metaworld(env_id, seed, **kwargs):
# New argument to use global seeding # New argument to use global seeding
_env.seeded_rand_vec = True _env.seeded_rand_vec = True
# Manually set spec, as metaworld environments are not registered via gym
# _env.unwrapped.spec = EnvSpec(env_id)
# Set Timelimit based on the maximum allowed path length of the environment
# _env = gym.wrappers.TimeLimit(_env, max_episode_steps=_env.max_path_length)
# _env.seed(seed)
# _env.action_space.seed(seed)
# _env.observation_space.seed(seed)
# _env.goal_space.seed(seed)
gym_id = uuid.uuid4().hex + '-v1' gym_id = uuid.uuid4().hex + '-v1'
register( register(
@ -319,28 +296,33 @@ def make_metaworld(env_id, seed, **kwargs):
# TODO enable checker when the incorrect dtype of obs and observation space are fixed by metaworld # TODO enable checker when the incorrect dtype of obs and observation space are fixed by metaworld
env = gym.make(gym_id, disable_env_checker=True) env = gym.make(gym_id, disable_env_checker=True)
env.seed(seed=seed)
return env return env
def make_gym(env_id, seed, **kwargs): def make_gym(env_id, seed, **kwargs):
# This access is required to allow for nested dict updates for BB envs """
spec = registry.get(env_id) Create
all_kwargs = deepcopy(spec.kwargs) Args:
env_id:
seed:
**kwargs:
Returns:
"""
# Getting the existing keywords to allow for nested dict updates for BB envs
# gym only allows for non nested updates.
all_kwargs = deepcopy(registry.get(env_id).kwargs)
nested_update(all_kwargs, kwargs) nested_update(all_kwargs, kwargs)
kwargs = all_kwargs kwargs = all_kwargs
# Add seed to kwargs in case it is a predefined gym+dmc hybrid environment. # Add seed to kwargs for bb environments to pass seed to step environments
# if env_id.startswith("dmc") or any(s in env_id.lower() for s in ['promp', 'dmp', 'prodmp']):
all_bb_envs = sum(alr_envs.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values(), []) all_bb_envs = sum(alr_envs.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values(), [])
if env_id.startswith("dmc") or env_id in all_bb_envs: if env_id in all_bb_envs:
kwargs.update({"seed": seed}) kwargs.update({"seed": seed})
# Gym # Gym
env = gym.make(env_id, **kwargs) env = gym.make(env_id, **kwargs)
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env return env