Merge remote-tracking branch 'origin/clean_api' into clean_api
This commit is contained in:
commit
da49d1b7f7
78
alr_envs/black_box/time_aware_observation.py
Normal file
78
alr_envs/black_box/time_aware_observation.py
Normal file
@ -0,0 +1,78 @@
|
||||
"""
|
||||
Adapted from: https://github.com/openai/gym/blob/907b1b20dd9ac0cba5803225059b9c6673702467/gym/wrappers/time_aware_observation.py
|
||||
License: MIT
|
||||
Copyright (c) 2016 OpenAI (https://openai.com)
|
||||
|
||||
Wrapper for adding time aware observations to environment observation.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
import gym
|
||||
from gym.spaces import Box
|
||||
|
||||
|
||||
class TimeAwareObservation(gym.ObservationWrapper):
    """Augment the observation with the current time step in the episode.

    The observation space of the wrapped environment is assumed to be a flat :class:`Box`.
    In particular, pixel observations are not supported. This wrapper appends the current
    timestep within the current episode to the observation.

    Example:
        >>> import gym
        >>> env = gym.make('CartPole-v1')
        >>> env = TimeAwareObservation(env)
        >>> env.reset()
        array([ 0.03810719,  0.03522411,  0.02231044, -0.01088205,  0.        ])
        >>> env.step(env.action_space.sample())[0]
        array([ 0.03881167, -0.16021058,  0.0220928 ,  0.28875574,  1.        ])
    """

    def __init__(self, env: gym.Env):
        """Initialize :class:`TimeAwareObservation`; requires a flat :class:`Box` observation space.

        Args:
            env: The environment to apply the wrapper to
        """
        super().__init__(env)
        assert isinstance(env.observation_space, Box), \
            "TimeAwareObservation requires a Box observation space"
        # Extend the bounds by one entry for the time step: it starts at 0 and is unbounded above.
        low = np.append(self.observation_space.low, 0.0)
        high = np.append(self.observation_space.high, np.inf)
        # Keep the dtype of the wrapped space so downstream consumers see the same precision.
        self.observation_space = Box(low, high, dtype=self.observation_space.dtype)
        # Number of steps taken since the last reset.
        self.t = 0

    def observation(self, observation):
        """Append the current time step to the observation.

        Args:
            observation: The observation to add the time step to

        Returns:
            The observation with the time step appended, in the dtype of the
            wrapper's observation space.
        """
        # np.append promotes e.g. float32 observations to float64 because the appended
        # Python int is converted to an integer array first. Cast back so the returned
        # observation matches the dtype declared in self.observation_space.
        stamped = np.append(observation, self.t)
        return stamped.astype(self.observation_space.dtype, copy=False)

    def step(self, action):
        """Step through the environment, incrementing the time step.

        Args:
            action: The action to take

        Returns:
            The environment's step result for the given action.
        """
        # Increment before delegating so the observation produced by this step
        # already carries the new time step.
        self.t += 1
        return super().step(action)

    def reset(self, **kwargs):
        """Reset the environment and set the time step back to zero.

        Args:
            **kwargs: Kwargs to apply to env.reset()

        Returns:
            Whatever the parent class's reset returns for the wrapped environment.
        """
        # Zero the counter first so the initial observation is stamped with t == 0.
        self.t = 0
        return super().reset(**kwargs)
|
@ -1,13 +1,11 @@
|
||||
import numpy as np
|
||||
|
||||
import alr_envs
|
||||
|
||||
|
||||
def example_mp(env_name="alr_envs:HoleReacherDMP-v1", seed=1, iterations=1, render=True):
|
||||
def example_mp(env_name="HoleReacherProMP-v0", seed=1, iterations=1, render=True):
|
||||
"""
|
||||
Example for running a motion primitive based environment, which is already registered
|
||||
Example for running a black box based environment, which is already registered
|
||||
Args:
|
||||
env_name: DMP env_id
|
||||
env_name: Black box env_id
|
||||
seed: seed for deterministic behaviour
|
||||
iterations: Number of rollout steps to run
|
||||
render: Render the episode
|
||||
@ -15,8 +13,8 @@ def example_mp(env_name="alr_envs:HoleReacherDMP-v1", seed=1, iterations=1, rend
|
||||
Returns:
|
||||
|
||||
"""
|
||||
# While in this case gym.make() is possible to use as well, we recommend our custom make env function.
|
||||
# First, it already takes care of seeding and second enables the use of DMC tasks within the gym interface.
|
||||
# Equivalent to gym, we have make function which can be used to create environments.
|
||||
# It takes care of seeding and enables the use of a variety of external environments using the gym interface.
|
||||
env = alr_envs.make(env_name, seed)
|
||||
|
||||
rewards = 0
|
||||
@ -37,8 +35,12 @@ def example_mp(env_name="alr_envs:HoleReacherDMP-v1", seed=1, iterations=1, rend
|
||||
else:
|
||||
env.render(mode=None)
|
||||
|
||||
# Now the action space is not the raw action but the parametrization of the trajectory generator,
|
||||
# such as a ProMP
|
||||
ac = env.action_space.sample()
|
||||
# This executes a full trajectory
|
||||
obs, reward, done, info = env.step(ac)
|
||||
# Aggregated reward
|
||||
rewards += reward
|
||||
|
||||
if done:
|
||||
|
@ -7,8 +7,7 @@ from typing import Iterable, Type, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
import alr_envs
|
||||
from gym.envs.registration import register, registry
|
||||
|
||||
try:
|
||||
from dm_control import suite, manipulation, composer
|
||||
@ -22,20 +21,16 @@ except Exception:
|
||||
# catch Exception due to Mujoco-py
|
||||
pass
|
||||
|
||||
from gym.envs.registration import registry
|
||||
from gym.envs.registration import register
|
||||
from gym.wrappers import TimeAwareObservation
|
||||
|
||||
import alr_envs
|
||||
from alr_envs.black_box.black_box_wrapper import BlackBoxWrapper
|
||||
from alr_envs.black_box.factory.basis_generator_factory import get_basis_generator
|
||||
from alr_envs.black_box.factory.controller_factory import get_controller
|
||||
from alr_envs.black_box.factory.phase_generator_factory import get_phase_generator
|
||||
from alr_envs.black_box.factory.trajectory_generator_factory import get_trajectory_generator
|
||||
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||
from alr_envs.black_box.time_aware_observation import TimeAwareObservation
|
||||
from alr_envs.utils.utils import nested_update
|
||||
|
||||
ALL_FRAMEWORK_TYPES = ['meta', 'dmc', 'gym']
|
||||
|
||||
|
||||
def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs):
|
||||
"""
|
||||
@ -63,34 +58,21 @@ def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwa
|
||||
return f if return_callable else f()
|
||||
|
||||
|
||||
def make(env_id, seed, **kwargs):
|
||||
return _make(env_id, seed, **kwargs)
|
||||
|
||||
|
||||
def _make(env_id: str, seed, **kwargs):
|
||||
def make(env_id: str, seed: int, **kwargs):
|
||||
"""
|
||||
Converts an env_id to an environment with the gym API.
|
||||
This also works for DeepMind Control Suite interface_wrappers
|
||||
for which domain name and task name are expected to be separated by "-".
|
||||
This also works for DeepMind Control Suite environments that are wrapped using the DMCWrapper, they can be
|
||||
specified with "dmc:domain_name-task_name"
|
||||
Analogously, metaworld tasks can be created as "metaworld:env_id-v2".
|
||||
|
||||
Args:
|
||||
env_id: gym name or env_id of the form "domain_name-task_name" for DMC tasks
|
||||
env_id: spec or env_id for gym tasks, external environments require a domain specification
|
||||
**kwargs: Additional kwargs for the constructor such as pixel observations, etc.
|
||||
|
||||
Returns: Gym environment
|
||||
|
||||
"""
|
||||
|
||||
# 'dmc:domain-task'
|
||||
# 'gym:name-vX'
|
||||
# 'meta:name-vX'
|
||||
# 'meta:bb:name-vX'
|
||||
# 'hand:name-vX'
|
||||
# 'name-vX'
|
||||
# 'bb:name-vX'
|
||||
#
|
||||
# env_id.split(':')
|
||||
# if 'dmc' :
|
||||
|
||||
if ':' in env_id:
|
||||
split_id = env_id.split(':')
|
||||
framework, env_id = split_id[-2:]
|
||||
@ -98,13 +80,17 @@ def _make(env_id: str, seed, **kwargs):
|
||||
framework = None
|
||||
|
||||
if framework == 'metaworld':
|
||||
# MetaWorld env
|
||||
env = make_metaworld(env_id, seed=seed, **kwargs)
|
||||
# MetaWorld environment
|
||||
env = make_metaworld(env_id, seed, **kwargs)
|
||||
elif framework == 'dmc':
|
||||
# DeepMind Controlp
|
||||
env = make_dmc(env_id, seed=seed, **kwargs)
|
||||
# DeepMind Control environment
|
||||
env = make_dmc(env_id, seed, **kwargs)
|
||||
else:
|
||||
env = make_gym(env_id, seed=seed, **kwargs)
|
||||
env = make_gym(env_id, seed, **kwargs)
|
||||
|
||||
env.seed(seed)
|
||||
env.action_space.seed(seed)
|
||||
env.observation_space.seed(seed)
|
||||
|
||||
return env
|
||||
|
||||
@ -285,7 +271,7 @@ def make_dmc(
|
||||
)
|
||||
|
||||
env = gym.make(gym_id)
|
||||
env.seed(seed=seed)
|
||||
env.seed(seed)
|
||||
return env
|
||||
|
||||
|
||||
@ -300,15 +286,6 @@ def make_metaworld(env_id, seed, **kwargs):
|
||||
# New argument to use global seeding
|
||||
_env.seeded_rand_vec = True
|
||||
|
||||
# Manually set spec, as metaworld environments are not registered via gym
|
||||
# _env.unwrapped.spec = EnvSpec(env_id)
|
||||
# Set Timelimit based on the maximum allowed path length of the environment
|
||||
# _env = gym.wrappers.TimeLimit(_env, max_episode_steps=_env.max_path_length)
|
||||
# _env.seed(seed)
|
||||
# _env.action_space.seed(seed)
|
||||
# _env.observation_space.seed(seed)
|
||||
# _env.goal_space.seed(seed)
|
||||
|
||||
gym_id = uuid.uuid4().hex + '-v1'
|
||||
|
||||
register(
|
||||
@ -319,28 +296,33 @@ def make_metaworld(env_id, seed, **kwargs):
|
||||
|
||||
# TODO enable checker when the incorrect dtype of obs and observation space are fixed by metaworld
|
||||
env = gym.make(gym_id, disable_env_checker=True)
|
||||
env.seed(seed=seed)
|
||||
return env
|
||||
|
||||
|
||||
def make_gym(env_id, seed, **kwargs):
|
||||
# This access is required to allow for nested dict updates for BB envs
|
||||
spec = registry.get(env_id)
|
||||
all_kwargs = deepcopy(spec.kwargs)
|
||||
"""
|
||||
Create a gym environment from a registered env_id, applying (possibly nested) keyword argument overrides.
|
||||
Args:
|
||||
env_id:
|
||||
seed:
|
||||
**kwargs:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
# Getting the existing keywords to allow for nested dict updates for BB envs
|
||||
# gym only allows for non nested updates.
|
||||
all_kwargs = deepcopy(registry.get(env_id).kwargs)
|
||||
nested_update(all_kwargs, kwargs)
|
||||
kwargs = all_kwargs
|
||||
|
||||
# Add seed to kwargs in case it is a predefined gym+dmc hybrid environment.
|
||||
# if env_id.startswith("dmc") or any(s in env_id.lower() for s in ['promp', 'dmp', 'prodmp']):
|
||||
# Add seed to kwargs for bb environments to pass seed to step environments
|
||||
all_bb_envs = sum(alr_envs.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values(), [])
|
||||
if env_id.startswith("dmc") or env_id in all_bb_envs:
|
||||
if env_id in all_bb_envs:
|
||||
kwargs.update({"seed": seed})
|
||||
|
||||
# Gym
|
||||
env = gym.make(env_id, **kwargs)
|
||||
env.seed(seed)
|
||||
env.action_space.seed(seed)
|
||||
env.observation_space.seed(seed)
|
||||
return env
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user