Merge pull request #12 from ALRhub/metaworld_integration

Metaworld integration

Commit 45ca0308c1
@@ -1,12 +1,12 @@
 from gym.envs.registration import register
 from gym.wrappers import FlattenObservation

-from alr_envs import classic_control, dmc, open_ai
+from alr_envs import classic_control, dmc, open_ai, meta

 from alr_envs.utils.make_env_helpers import make_dmp_env
 from alr_envs.utils.make_env_helpers import make_detpmp_env
-from alr_envs.utils.make_env_helpers import make_env
-from alr_envs.utils.make_env_helpers import make_env_rank
+from alr_envs.utils.make_env_helpers import make
+from alr_envs.utils.make_env_helpers import make_rank

 # Mujoco

@@ -305,18 +305,17 @@ register(
     # max_episode_steps=1,
     kwargs={
         "name": f"ball_in_cup-catch",
-        "time_limit": 1,
+        "time_limit": 2,
-        "episode_length": 50,
+        "episode_length": 100,
         "wrappers": [dmc.suite.ball_in_cup.MPWrapper],
         "mp_kwargs": {
             "num_dof": 2,
             "num_basis": 5,
-            "duration": 1,
+            "duration": 2,
             "learn_goal": True,
             "alpha_phase": 2,
             "bandwidth_factor": 2,
             "policy_type": "motor",
-            "weights_scale": 50,
             "goal_scale": 0.1,
             "policy_kwargs": {
                 "p_gains": 50,
@@ -331,16 +330,15 @@ register(
     entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
     kwargs={
         "name": f"ball_in_cup-catch",
-        "time_limit": 1,
+        "time_limit": 2,
-        "episode_length": 50,
+        "episode_length": 100,
         "wrappers": [dmc.suite.ball_in_cup.MPWrapper],
         "mp_kwargs": {
             "num_dof": 2,
             "num_basis": 5,
-            "duration": 1,
+            "duration": 2,
             "width": 0.025,
             "policy_type": "motor",
-            "weights_scale": 0.2,
             "zero_start": True,
             "policy_kwargs": {
                 "p_gains": 50,
@@ -828,6 +826,7 @@ register(
             "duration": 2,
             "post_traj_time": 0,
             "width": 0.02,
+            "zero_start": True,
             "policy_type": "motor",
             "policy_kwargs": {
                 "p_gains": 1.,
@@ -849,6 +848,7 @@ register(
             "duration": 1,
             "post_traj_time": 0,
             "width": 0.02,
+            "zero_start": True,
             "policy_type": "motor",
             "policy_kwargs": {
                 "p_gains": .6,
@@ -870,6 +870,25 @@ register(
             "duration": 2,
             "post_traj_time": 0,
             "width": 0.02,
+            "zero_start": True,
+            "policy_type": "position"
+        }
+    }
+)
+
+register(
+    id='FetchSlideDetPMP-v1',
+    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+    kwargs={
+        "name": "gym.envs.robotics:FetchSlide-v1",
+        "wrappers": [FlattenObservation, open_ai.robotics.fetch.MPWrapper],
+        "mp_kwargs": {
+            "num_dof": 4,
+            "num_basis": 5,
+            "duration": 2,
+            "post_traj_time": 0,
+            "width": 0.02,
+            "zero_start": True,
             "policy_type": "position"
         }
     }
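For context, a minimal usage sketch (not part of the diff) for the id registered above; it assumes `alr_envs` is importable and the gym Fetch robotics environments are installed:

```python
import alr_envs

# "FetchSlideDetPMP-v1" is registered by the hunk above.
env = alr_envs.make("FetchSlideDetPMP-v1", seed=1)
obs = env.reset()
# For motion primitive environments a single step executes a full trajectory.
obs, reward, done, info = env.step(env.action_space.sample())
```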
@@ -887,7 +906,127 @@ register(
             "duration": 2,
             "post_traj_time": 0,
             "width": 0.02,
+            "zero_start": True,
             "policy_type": "position"
         }
     }
 )
+
+register(
+    id='FetchReachDetPMP-v1',
+    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+    kwargs={
+        "name": "gym.envs.robotics:FetchReach-v1",
+        "wrappers": [FlattenObservation, open_ai.robotics.fetch.MPWrapper],
+        "mp_kwargs": {
+            "num_dof": 4,
+            "num_basis": 5,
+            "duration": 2,
+            "post_traj_time": 0,
+            "width": 0.02,
+            "zero_start": True,
+            "policy_type": "position"
+        }
+    }
+)
+
+# MetaWorld
+
+goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
+                    ]
+for env_id in goal_change_envs:
+    env_id_split = env_id.split("-")
+    name = "".join([s.capitalize() for s in env_id_split[:-1]])
+    register(
+        id=f'{name}DetPMP-{env_id_split[-1]}',
+        entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+        kwargs={
+            "name": env_id,
+            "wrappers": [meta.goal_change.MPWrapper],
+            "mp_kwargs": {
+                "num_dof": 4,
+                "num_basis": 5,
+                "duration": 6.25,
+                "post_traj_time": 0,
+                "width": 0.025,
+                "zero_start": True,
+                "policy_type": "metaworld",
+            }
+        }
+    )
+
+object_change_envs = ["bin-picking-v2", "hammer-v2", "sweep-into-v2"]
+for env_id in object_change_envs:
+    env_id_split = env_id.split("-")
+    name = "".join([s.capitalize() for s in env_id_split[:-1]])
+    register(
+        id=f'{name}DetPMP-{env_id_split[-1]}',
+        entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+        kwargs={
+            "name": env_id,
+            "wrappers": [meta.object_change.MPWrapper],
+            "mp_kwargs": {
+                "num_dof": 4,
+                "num_basis": 5,
+                "duration": 6.25,
+                "post_traj_time": 0,
+                "width": 0.025,
+                "zero_start": True,
+                "policy_type": "metaworld",
+            }
+        }
+    )
+
+goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press-wall-v2", "button-press-topdown-v2",
+    "button-press-topdown-wall-v2", "coffee-button-v2", "coffee-pull-v2",
+    "coffee-push-v2", "dial-turn-v2", "disassemble-v2", "door-close-v2",
+    "door-lock-v2", "door-open-v2", "door-unlock-v2", "hand-insert-v2",
+    "drawer-close-v2", "drawer-open-v2", "faucet-open-v2", "faucet-close-v2",
+    "handle-press-side-v2", "handle-press-v2", "handle-pull-side-v2",
+    "handle-pull-v2", "lever-pull-v2", "peg-insert-side-v2", "pick-place-wall-v2",
+    "reach-v2", "push-back-v2", "push-v2", "pick-place-v2", "peg-unplug-side-v2",
+    "soccer-v2", "stick-push-v2", "stick-pull-v2", "push-wall-v2", "reach-wall-v2",
+    "shelf-place-v2", "sweep-v2", "window-open-v2", "window-close-v2"
+    ]
+for env_id in goal_and_object_change_envs:
+    env_id_split = env_id.split("-")
+    name = "".join([s.capitalize() for s in env_id_split[:-1]])
+    register(
+        id=f'{name}DetPMP-{env_id_split[-1]}',
+        entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+        kwargs={
+            "name": env_id,
+            "wrappers": [meta.goal_and_object_change.MPWrapper],
+            "mp_kwargs": {
+                "num_dof": 4,
+                "num_basis": 5,
+                "duration": 6.25,
+                "post_traj_time": 0,
+                "width": 0.025,
+                "zero_start": True,
+                "policy_type": "metaworld",
+            }
+        }
+    )
+
+goal_and_endeffector_change_envs = ["basketball-v2"]
+for env_id in goal_and_endeffector_change_envs:
+    env_id_split = env_id.split("-")
+    name = "".join([s.capitalize() for s in env_id_split[:-1]])
+    register(
+        id=f'{name}DetPMP-{env_id_split[-1]}',
+        entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+        kwargs={
+            "name": env_id,
+            "wrappers": [meta.goal_and_endeffector_change.MPWrapper],
+            "mp_kwargs": {
+                "num_dof": 4,
+                "num_basis": 5,
+                "duration": 6.25,
+                "post_traj_time": 0,
+                "width": 0.025,
+                "zero_start": True,
+                "policy_type": "metaworld",
+            }
+        }
+    )
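As a hedged illustration of the ids generated by the loops above (the `{TaskName}DetPMP-v2` naming pattern and the `ButtonPressDetPMP-v2` id are taken from this commit's examples):

```python
import alr_envs

# "button-press-v2" -> "ButtonPressDetPMP-v2"
env = alr_envs.make("ButtonPressDetPMP-v2", seed=1)

obs = env.reset()
# One step of a DetPMP environment rolls out the whole motion primitive trajectory.
obs, reward, done, info = env.step(env.action_space.sample())
print(reward, done)
```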
alr_envs/dmc/README.MD (new file, 3 lines)
@@ -0,0 +1,3 @@
# DeepMind Control (DMC) Wrappers

These are the Environment Wrappers for selected [DeepMind Control](https://deepmind.com/research/publications/2020/dm-control-Software-and-Tasks-for-Continuous-Control) environments in order to use our Motion Primitive gym interface with them.
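A short, hedged usage sketch for these wrappers (the `alr_envs.make` helper and the `domain_name-task_name` id format are used elsewhere in this commit):

```python
import alr_envs

env = alr_envs.make("ball_in_cup-catch", seed=1)  # DMC task, addressed as "domain_name-task_name"
obs = env.reset()
for _ in range(10):
    obs, reward, done, info = env.step(env.action_space.sample())
    if done:
        obs = env.reset()
```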
@@ -17,7 +17,7 @@ def example_dmc(env_id="fish-swim", seed=1, iterations=1000, render=True):
     Returns:

     """
-    env = alr_envs.make_env(env_id, seed)
+    env = alr_envs.make(env_id, seed)
     rewards = 0
     obs = env.reset()
     print("observation shape:", env.observation_space.shape)
@@ -21,7 +21,7 @@ def example_general(env_id="Pendulum-v0", seed=1, iterations=1000, render=True):

     """

-    env = alr_envs.make_env(env_id, seed)
+    env = alr_envs.make(env_id, seed)
     rewards = 0
     obs = env.reset()
     print("Observation shape: ", env.observation_space.shape)
@@ -56,7 +56,7 @@ def example_async(env_id="alr_envs:HoleReacher-v0", n_cpu=4, seed=int('533D', 16
     Returns: Tuple of (obs, reward, done, info) with type np.ndarray

     """
-    env = gym.vector.AsyncVectorEnv([alr_envs.make_env_rank(env_id, seed, i) for i in range(n_cpu)])
+    env = gym.vector.AsyncVectorEnv([alr_envs.make_rank(env_id, seed, i) for i in range(n_cpu)])
     # OR
     # envs = gym.vector.AsyncVectorEnv([make_env(env_id, seed + i) for i in range(n_cpu)])

@@ -80,20 +80,21 @@ def example_async(env_id="alr_envs:HoleReacher-v0", n_cpu=4, seed=int('533D', 16
         rewards[done] = 0

     # do not return values above threshold
-    return *map(lambda v: np.stack(v)[:n_samples], buffer.values()),
+    return (*map(lambda v: np.stack(v)[:n_samples], buffer.values()),)


 if __name__ == '__main__':
-    render = False
+    render = True

     # Basic gym task
     example_general("Pendulum-v0", seed=10, iterations=200, render=render)
-    #
+
     # # Basis task from framework
     example_general("alr_envs:HoleReacher-v0", seed=10, iterations=200, render=render)
-    #
+
     # # OpenAI Mujoco task
     example_general("HalfCheetah-v2", seed=10, render=render)
-    #
+
     # # Mujoco task from framework
     example_general("alr_envs:ALRReacher-v0", seed=10, iterations=200, render=render)

alr_envs/examples/examples_metaworld.py (new file, 128 lines)
@@ -0,0 +1,128 @@
import alr_envs
from alr_envs.meta.goal_and_object_change import MPWrapper


def example_dmc(env_id="fish-swim", seed=1, iterations=1000, render=True):
    """
    Example for running a MetaWorld based env in the step based setting.
    The env_id has to be specified as `task_name-v2`. V1 versions are not supported and we always
    return the observable goal version.
    All tasks can be found here: https://arxiv.org/pdf/1910.10897.pdf or https://meta-world.github.io/

    Args:
        env_id: `task_name-v2`
        seed: seed for deterministic behaviour (TODO: currently not working due to an issue in MetaWorld code)
        iterations: Number of rollout steps to run
        render: Render the episode

    Returns:

    """
    env = alr_envs.make(env_id, seed)
    rewards = 0
    obs = env.reset()
    print("observation shape:", env.observation_space.shape)
    print("action shape:", env.action_space.shape)

    for i in range(iterations):
        ac = env.action_space.sample()
        obs, reward, done, info = env.step(ac)
        rewards += reward

        if render:
            # THIS NEEDS TO BE SET TO FALSE FOR NOW, BECAUSE THE INTERFACE FOR RENDERING IS DIFFERENT TO BASIC GYM
            # TODO: Remove this, when Metaworld fixes its interface.
            env.render(False)

        if done:
            print(env_id, rewards)
            rewards = 0
            obs = env.reset()

    env.close()
    del env


def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
    """
    Example for running a custom motion primitive based environment.
    Our already registered environments follow the same structure.
    Hence, this also allows to adjust hyperparameters of the motion primitives.
    Yet, we recommend the method above if you are just interested in changing those parameters for existing tasks.
    We appreciate PRs for custom environments (especially MP wrappers of existing tasks)
    for our repo: https://github.com/ALRhub/alr_envs/
    Args:
        seed: seed for deterministic behaviour (TODO: currently not working due to an issue in MetaWorld code)
        iterations: Number of rollout steps to run
        render: Render the episode (TODO: currently not working due to an issue in MetaWorld code)

    Returns:

    """

    # Base MetaWorld name, according to structure of above example
    base_env = "button-press-v2"

    # Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper.
    # You can also add other gym.Wrappers in case they are needed.
    wrappers = [MPWrapper]
    mp_kwargs = {
        "num_dof": 4,
        "num_basis": 5,
        "duration": 6.25,
        "post_traj_time": 0,
        "width": 0.025,
        "zero_start": True,
        "policy_type": "metaworld",
    }

    env = alr_envs.make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
    # OR for a DMP:
    # env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)

    # This renders the full MP trajectory
    # It is only required to call render() once in the beginning, which renders every consecutive trajectory.
    # Resetting to no rendering can be achieved by render(mode=None).
    # It is also possible to change the mode multiple times when
    # e.g. only every nth trajectory should be displayed.
    if render:
        raise ValueError("The Metaworld render interface is currently broken and cannot be used until Metaworld "
                         "fixes its interface. A temporary workaround is to alter their code in MujocoEnv render() "
                         "from `if not offscreen` to `if not offscreen or offscreen == 'human'`.")
        # TODO: Remove this, when Metaworld fixes its interface.
        # env.render(mode="human")

    rewards = 0
    obs = env.reset()

    # number of samples/full trajectories (multiple environment steps)
    for i in range(iterations):
        ac = env.action_space.sample()
        obs, reward, done, info = env.step(ac)
        rewards += reward

        if done:
            print(base_env, rewards)
            rewards = 0
            obs = env.reset()

    env.close()
    del env


if __name__ == '__main__':
    # Disclaimer: MetaWorld environments require the seed to be specified in the beginning.
    # Adjusting it afterwards with env.seed() is not recommended as it may not affect the underlying behavior.

    # For rendering it might be necessary to specify your OpenGL installation
    # export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libGLEW.so
    render = False

    # # Standard MetaWorld task in the step-based setting
    example_dmc("button-press-v2", seed=10, iterations=500, render=render)

    # MP + MetaWorld hybrid task provided in our framework
    example_dmc("ButtonPressDetPMP-v2", seed=10, iterations=1, render=render)

    # Custom MetaWorld task
    example_custom_dmc_and_mp(seed=10, iterations=1, render=render)
@@ -1,5 +1,4 @@
-from alr_envs import MPWrapper
-from alr_envs.utils.make_env_helpers import make_dmp_env, make_env
+import alr_envs


 def example_mp(env_name="alr_envs:HoleReacherDMP-v1", seed=1, iterations=1, render=True):
@@ -16,7 +15,7 @@ def example_mp(env_name="alr_envs:HoleReacherDMP-v1", seed=1, iterations=1, rend
     """
     # While in this case gym.make() is possible to use as well, we recommend our custom make env function.
     # First, it already takes care of seeding and second enables the use of DMC tasks within the gym interface.
-    env = make_env(env_name, seed)
+    env = alr_envs.make(env_name, seed)

     rewards = 0
     # env.render(mode=None)
@@ -71,7 +70,7 @@ def example_custom_mp(env_name="alr_envs:HoleReacherDMP-v1", seed=1, iterations=
         "weights_scale": 50,
         "goal_scale": 0.1
     }
-    env = make_env(env_name, seed, mp_kwargs=mp_kwargs)
+    env = alr_envs.make(env_name, seed, mp_kwargs=mp_kwargs)

     # This time rendering every trajectory
     if render:
@@ -113,7 +112,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):

     # Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper.
     # You can also add other gym.Wrappers in case they are needed.
-    wrappers = [MPWrapper]
+    wrappers = [alr_envs.classic_control.hole_reacher.MPWrapper]
     mp_kwargs = {
         "num_dof": 5,
         "num_basis": 5,
@@ -125,7 +124,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
         "weights_scale": 50,
         "goal_scale": 0.1
     }
-    env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
+    env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
     # OR for a deterministic ProMP:
     # env = make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)

@@ -1,4 +1,4 @@
-from alr_envs.utils.make_env_helpers import make_env
+import alr_envs


 def example_mp(env_name, seed=1):
@@ -13,7 +13,7 @@ def example_mp(env_name, seed=1):

     """
     # While in this case gym.make() is possible to use as well, we recommend our custom make env function.
-    env = make_env(env_name, seed)
+    env = alr_envs.make(env_name, seed)

     rewards = 0
     obs = env.reset()
@@ -29,13 +29,13 @@ def example_mp(env_name, seed=1):
             rewards = 0
             obs = env.reset()


 if __name__ == '__main__':
     # DMP - not supported yet
-    #example_mp("ReacherDetPMP-v2")
+    # example_mp("ReacherDMP-v2")

     # DetProMP
     example_mp("ContinuousMountainCarDetPMP-v0")
     example_mp("ReacherDetPMP-v2")
     example_mp("FetchReachDenseDetPMP-v1")
     example_mp("FetchSlideDenseDetPMP-v1")

@@ -1,30 +1,30 @@
 import numpy as np
 from matplotlib import pyplot as plt

-from alr_envs import dmc
+from alr_envs import dmc, meta
 from alr_envs.utils.make_env_helpers import make_detpmp_env

 # This might work for some environments, however, please verify either way the correct trajectory information
 # for your environment are extracted below
 SEED = 10
-env_id = "cartpole-swingup"
+env_id = "ball_in_cup-catch"
-wrappers = [dmc.suite.cartpole.MPWrapper]
+wrappers = [dmc.ball_in_cup.MPWrapper]

 mp_kwargs = {
-    "num_dof": 1,
+    "num_dof": 2,
-    "num_basis": 5,
+    "num_basis": 10,
     "duration": 2,
     "width": 0.025,
     "policy_type": "motor",
-    "weights_scale": 0.2,
+    "weights_scale": 1,
     "zero_start": True,
     "policy_kwargs": {
-        "p_gains": 10,
+        "p_gains": 1,
-        "d_gains": 10  # a good starting point is the sqrt of p_gains
+        "d_gains": 1
     }
 }

-kwargs = dict(time_limit=2, episode_length=200)
+kwargs = dict(time_limit=2, episode_length=100)

 env = make_detpmp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs,
                       **kwargs)
@@ -35,7 +35,6 @@ pos, vel = env.mp_rollout(env.action_space.sample())

 base_shape = env.full_action_space.shape
 actual_pos = np.zeros((len(pos), *base_shape))
-actual_pos_ball = np.zeros((len(pos), *base_shape))
 actual_vel = np.zeros((len(pos), *base_shape))
 act = np.zeros((len(pos), *base_shape))

@@ -46,7 +45,6 @@ for t, pos_vel in enumerate(zip(pos, vel)):
     act[t, :] = actions
     # TODO verify for your environment
     actual_pos[t, :] = env.current_pos
-    # actual_pos_ball[t, :] = env.physics.data.qpos[2:]
     actual_vel[t, :] = env.current_vel

 plt.figure(figsize=(15, 5))
alr_envs/meta/README.MD (new file, 26 lines)
@@ -0,0 +1,26 @@
# MetaWorld Wrappers

These are the Environment Wrappers for selected [Metaworld](https://meta-world.github.io/) environments in order to use our Motion Primitive gym interface with them.
All Metaworld environments have a 39-dimensional observation space with the same structure. The tasks differ only in the objective and the initial observations, which are randomized.
Unused observations are zeroed out. E.g. for `Button-Press-v2` the observation mask looks as follows:
```python
return np.hstack([
    # Current observation
    [False] * 3,  # end-effector position
    [False] * 1,  # normalized gripper open distance
    [True] * 3,  # main object position
    [False] * 4,  # main object quaternion
    [False] * 3,  # secondary object position
    [False] * 4,  # secondary object quaternion
    # Previous observation
    [False] * 3,  # previous end-effector position
    [False] * 1,  # previous normalized gripper open distance
    [False] * 3,  # previous main object position
    [False] * 4,  # previous main object quaternion
    [False] * 3,  # previous second object position
    [False] * 4,  # previous second object quaternion
    # Goal
    [True] * 3,  # goal position
])
```
For other tasks only the boolean values have to be adjusted accordingly.
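To make the effect of such a mask concrete, a small sketch (not part of the commit) of how a boolean mask reduces the 39-dimensional observation to its active entries:

```python
import numpy as np

obs = np.arange(39.0)  # stand-in for a Metaworld observation
mask = np.hstack([[False] * 3, [False] * 1, [True] * 3, [False] * 29, [True] * 3])
assert mask.shape == obs.shape
active_obs = obs[mask]  # only the main object position and the goal remain
print(active_obs.shape)  # (6,)
```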
alr_envs/meta/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from alr_envs.meta import goal_and_object_change, goal_and_endeffector_change, goal_change, object_change
alr_envs/meta/goal_and_endeffector_change.py (new file, 68 lines)
@@ -0,0 +1,68 @@
from typing import Tuple, Union

import numpy as np

from mp_env_api import MPEnvWrapper


class MPWrapper(MPEnvWrapper):
    """
    This Wrapper is for environments where the goal and the end-effector position change at the start
    of an episode, while no secondary objects are altered.
    You can verify this by executing the code below for your environment id and checking whether the output
    is non-zero at the same indices.
    ```python
    import alr_envs
    env = alr_envs.make(env_id, 1)
    print(env.reset() - env.reset())
    array([ !=0 , !=0 , !=0 , 0. , 0.,
            0. , 0. , 0. , 0. , 0. ,
            0. , 0. , 0. , 0. , 0. ,
            0. , 0. , 0. , !=0 , !=0 ,
            !=0 , 0. , 0. , 0. , 0. ,
            0. , 0. , 0. , 0. , 0. ,
            0. , 0. , 0. , 0. , 0. ,
            0. , !=0 , !=0 , !=0])
    ```
    """

    @property
    def active_obs(self):
        # This structure is the same for all metaworld environments.
        # Only the observations which change could differ
        return np.hstack([
            # Current observation
            [True] * 3,  # end-effector position
            [False] * 1,  # normalized gripper open distance
            [False] * 3,  # main object position
            [False] * 4,  # main object quaternion
            [False] * 3,  # secondary object position
            [False] * 4,  # secondary object quaternion
            # Previous observation
            # TODO: Include previous values? According to their source they might be wrong for the first iteration.
            [False] * 3,  # previous end-effector position
            [False] * 1,  # previous normalized gripper open distance
            [False] * 3,  # previous main object position
            [False] * 4,  # previous main object quaternion
            [False] * 3,  # previous second object position
            [False] * 4,  # previous second object quaternion
            # Goal
            [True] * 3,  # goal position
        ])

    @property
    def current_pos(self) -> Union[float, int, np.ndarray]:
        r_close = self.env.data.get_joint_qpos("r_close")
        return np.hstack([self.env.data.mocap_pos.flatten(), r_close])

    @property
    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
        raise NotImplementedError("Velocity cannot be retrieved.")

    @property
    def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
        raise ValueError("Goal position is not available and has to be learnt based on the environment.")

    @property
    def dt(self) -> Union[float, int]:
        return self.env.dt
alr_envs/meta/goal_and_object_change.py (new file, 68 lines)
@@ -0,0 +1,68 @@
from typing import Tuple, Union

import numpy as np

from mp_env_api import MPEnvWrapper


class MPWrapper(MPEnvWrapper):
    """
    This Wrapper is for environments where the goal and the main object position change at the start
    of an episode, while the end effector and secondary objects are not altered.
    You can verify this by executing the code below for your environment id and checking whether the output
    is non-zero at the same indices.
    ```python
    import alr_envs
    env = alr_envs.make(env_id, 1)
    print(env.reset() - env.reset())
    array([ 0. , 0. , 0. , 0. , !=0,
            !=0 , !=0 , 0. , 0. , 0. ,
            0. , 0. , 0. , 0. , 0. ,
            0. , 0. , 0. , 0. , 0. ,
            0. , 0. , !=0 , !=0 , !=0 ,
            0. , 0. , 0. , 0. , 0. ,
            0. , 0. , 0. , 0. , 0. ,
            0. , !=0 , !=0 , !=0])
    ```
    """

    @property
    def active_obs(self):
        # This structure is the same for all metaworld environments.
        # Only the observations which change could differ
        return np.hstack([
            # Current observation
            [False] * 3,  # end-effector position
            [False] * 1,  # normalized gripper open distance
            [True] * 3,  # main object position
            [False] * 4,  # main object quaternion
            [False] * 3,  # secondary object position
            [False] * 4,  # secondary object quaternion
            # Previous observation
            # TODO: Include previous values? According to their source they might be wrong for the first iteration.
            [False] * 3,  # previous end-effector position
            [False] * 1,  # previous normalized gripper open distance
            [False] * 3,  # previous main object position
            [False] * 4,  # previous main object quaternion
            [False] * 3,  # previous second object position
            [False] * 4,  # previous second object quaternion
            # Goal
            [True] * 3,  # goal position
        ])

    @property
    def current_pos(self) -> Union[float, int, np.ndarray]:
        r_close = self.env.data.get_joint_qpos("r_close")
        return np.hstack([self.env.data.mocap_pos.flatten(), r_close])

    @property
    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
        raise NotImplementedError("Velocity cannot be retrieved.")

    @property
    def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
        raise ValueError("Goal position is not available and has to be learnt based on the environment.")

    @property
    def dt(self) -> Union[float, int]:
        return self.env.dt
alr_envs/meta/goal_change.py (new file, 68 lines)
@@ -0,0 +1,68 @@
from typing import Tuple, Union

import numpy as np

from mp_env_api import MPEnvWrapper


class MPWrapper(MPEnvWrapper):
    """
    This Wrapper is for environments where merely the goal changes at the start of an episode
    and no secondary objects or end effectors are altered.
    You can verify this by executing the code below for your environment id and checking whether the output
    is non-zero at the same indices.
    ```python
    import alr_envs
    env = alr_envs.make(env_id, 1)
    print(env.reset() - env.reset())
    array([ 0. , 0. , 0. , 0. , 0,
            0 , 0 , 0. , 0. , 0. ,
            0. , 0. , 0. , 0. , 0. ,
            0. , 0. , 0. , 0. , 0. ,
            0. , 0. , 0 , 0 , 0 ,
            0. , 0. , 0. , 0. , 0. ,
            0. , 0. , 0. , 0. , 0. ,
            0. , !=0 , !=0 , !=0])
    ```
    """

    @property
    def active_obs(self):
        # This structure is the same for all metaworld environments.
        # Only the observations which change could differ
        return np.hstack([
            # Current observation
            [False] * 3,  # end-effector position
            [False] * 1,  # normalized gripper open distance
            [False] * 3,  # main object position
            [False] * 4,  # main object quaternion
            [False] * 3,  # secondary object position
            [False] * 4,  # secondary object quaternion
            # Previous observation
            # TODO: Include previous values? According to their source they might be wrong for the first iteration.
            [False] * 3,  # previous end-effector position
            [False] * 1,  # previous normalized gripper open distance
            [False] * 3,  # previous main object position
            [False] * 4,  # previous main object quaternion
            [False] * 3,  # previous second object position
            [False] * 4,  # previous second object quaternion
            # Goal
            [True] * 3,  # goal position
        ])

    @property
    def current_pos(self) -> Union[float, int, np.ndarray]:
        r_close = self.env.data.get_joint_qpos("r_close")
        return np.hstack([self.env.data.mocap_pos.flatten(), r_close])

    @property
    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
        raise NotImplementedError("Velocity cannot be retrieved.")

    @property
    def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
        raise ValueError("Goal position is not available and has to be learnt based on the environment.")

    @property
    def dt(self) -> Union[float, int]:
        return self.env.dt
alr_envs/meta/object_change.py (new file, 68 lines)
@@ -0,0 +1,68 @@
from typing import Tuple, Union

import numpy as np

from mp_env_api import MPEnvWrapper


class MPWrapper(MPEnvWrapper):
    """
    This Wrapper is for environments where the main object position changes at the start of an episode,
    while the goal and the end effector are not altered.
    You can verify this by executing the code below for your environment id and checking whether the output
    is non-zero at the same indices.
    ```python
    import alr_envs
    env = alr_envs.make(env_id, 1)
    print(env.reset() - env.reset())
    array([ 0. , 0. , 0. , 0. , !=0 ,
            !=0 , !=0 , 0. , 0. , 0. ,
            0. , 0. , 0. , 0. , 0. ,
            0. , 0. , 0. , 0. , 0. ,
            0. , 0. , 0 , 0 , 0 ,
            0. , 0. , 0. , 0. , 0. ,
            0. , 0. , 0. , 0. , 0. ,
            0. , 0. , 0. , 0.])
    ```
    """

    @property
    def active_obs(self):
        # This structure is the same for all metaworld environments.
        # Only the observations which change could differ
        return np.hstack([
            # Current observation
            [False] * 3,  # end-effector position
            [False] * 1,  # normalized gripper open distance
            [False] * 3,  # main object position
            [False] * 4,  # main object quaternion
            [False] * 3,  # secondary object position
            [False] * 4,  # secondary object quaternion
            # Previous observation
            # TODO: Include previous values? According to their source they might be wrong for the first iteration.
            [False] * 3,  # previous end-effector position
            [False] * 1,  # previous normalized gripper open distance
            [False] * 3,  # previous main object position
            [False] * 4,  # previous main object quaternion
            [False] * 3,  # previous second object position
            [False] * 4,  # previous second object quaternion
            # Goal
            [True] * 3,  # goal position
        ])

    @property
    def current_pos(self) -> Union[float, int, np.ndarray]:
        r_close = self.env.data.get_joint_qpos("r_close")
        return np.hstack([self.env.data.mocap_pos.flatten(), r_close])

    @property
    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
        raise NotImplementedError("Velocity cannot be retrieved.")

    @property
    def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
        raise ValueError("Goal position is not available and has to be learnt based on the environment.")

    @property
    def dt(self) -> Union[float, int]:
        return self.env.dt
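The docstrings above suggest verifying which observation indices vary between resets. A hedged sketch of that check for one of the ids handled by this wrapper (`bin-picking-v2` is listed in `object_change_envs` above):

```python
import numpy as np
import alr_envs

env = alr_envs.make("bin-picking-v2", seed=1)
delta = env.reset() - env.reset()
print(np.nonzero(delta))  # indices that change at the start of an episode
```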
alr_envs/open_ai/README.MD (new file, 3 lines)
@@ -0,0 +1,3 @@
# OpenAI Gym Wrappers

These are the Environment Wrappers for selected [OpenAI Gym](https://gym.openai.com/) environments in order to use our Motion Primitive gym interface with them.
@@ -4,8 +4,10 @@ from typing import Union
 import gym
 from gym.envs.registration import register

+from alr_envs.utils.make_env_helpers import make
+

-def make(
+def make_dmc(
     id: str,
     seed: int = 1,
     visualize_reward: bool = True,
@@ -3,21 +3,22 @@ from typing import Iterable, List, Type, Union

 import gym
 import numpy as np
+from gym.envs.registration import EnvSpec

 from mp_env_api import MPEnvWrapper
 from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
 from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper


-def make_env_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs):
+def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs):
     """
     TODO: Do we need this?
     Generate a callable to create a new gym environment with a given seed.
     The rank is added to the seed and can be used for example when using vector environments.
-    E.g. [make_env_rank("my_env_name-v0", 123, i) for i in range(8)] creates a list of 8 environments
+    E.g. [make_rank("my_env_name-v0", 123, i) for i in range(8)] creates a list of 8 environments
     with seeds 123 through 130.
     Hence, testing environments should be seeded with a value which is offset by the number of training environments.
-    Here e.g. [make_env_rank("my_env_name-v0", 123 + 8, i) for i in range(5)] for 5 testing environments
+    Here e.g. [make_rank("my_env_name-v0", 123 + 8, i) for i in range(5)] for 5 testing environments

     Args:
         env_id: name of the environment
@@ -30,12 +31,12 @@ def make_env_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, *
     """

     def f():
-        return make_env(env_id, seed + rank, **kwargs)
+        return make(env_id, seed + rank, **kwargs)

     return f if return_callable else f()


-def make_env(env_id: str, seed, **kwargs):
+def make(env_id: str, seed, **kwargs):
     """
     Converts an env_id to an environment with the gym API.
     This also works for DeepMind Control Suite interface_wrappers
@@ -58,9 +59,26 @@ def make_env(env_id: str, seed, **kwargs):
         env.action_space.seed(seed)
         env.observation_space.seed(seed)
     except gym.error.Error:

+        # MetaWorld env
+        import metaworld
+        if env_id in metaworld.ML1.ENV_NAMES:
+            env = metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id + "-goal-observable"](seed=seed, **kwargs)
+            # setting this avoids generating the same initialization after each reset
+            env._freeze_rand_vec = False
+            # Manually set spec, as metaworld environments are not registered via gym
+            env.unwrapped.spec = EnvSpec(env_id)
+            # Set Timelimit based on the maximum allowed path length of the environment
+            env = gym.wrappers.TimeLimit(env, max_episode_steps=env.max_path_length)
+            env.seed(seed)
+            env.action_space.seed(seed)
+            env.observation_space.seed(seed)
+            env.goal_space.seed(seed)
+
+        else:
             # DMC
-            from alr_envs.utils import make
+            from alr_envs.utils import make_dmc
-            env = make(env_id, seed=seed, **kwargs)
+            env = make_dmc(env_id, seed=seed, **kwargs)

     assert env.base_step_limit == env.spec.max_episode_steps, \
         f"The specified 'episode_length' of {env.spec.max_episode_steps} steps for gym is different from " \
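A brief, hedged sketch of how the new MetaWorld branch above is reached (assuming `metaworld` is installed):

```python
import alr_envs

# "button-press-v2" is not registered with gym, so make() falls through to the MetaWorld branch.
env = alr_envs.make("button-press-v2", seed=1)
print(env.spec.id, env.action_space.shape)
```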
@@ -84,7 +102,7 @@ def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1

     """
     # _env = gym.make(env_id)
-    _env = make_env(env_id, seed, **kwargs)
+    _env = make(env_id, seed, **kwargs)

     assert any(issubclass(w, MPEnvWrapper) for w in wrappers), \
         "At least one MPEnvWrapper is required in order to leverage motion primitive environments."
@@ -175,7 +193,7 @@ def make_detpmp_env_helper(**kwargs):


 def make_contextual_env(env_id, context, seed, rank):
-    env = make_env(env_id, seed + rank, context=context)
+    env = make(env_id, seed + rank, context=context)
     # env = gym.make(env_id, context=context)
     # env.seed(seed + rank)
     return lambda: env
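A usage sketch for the renamed helper (mirroring the `make_rank` docstring and the async example in this commit; the HoleReacher id is only a placeholder):

```python
import gym
import alr_envs

n_train, n_test = 8, 5
# Training workers receive seeds 123 .. 130.
train_envs = gym.vector.AsyncVectorEnv(
    [alr_envs.make_rank("alr_envs:HoleReacher-v0", 123, i) for i in range(n_train)])
# Test workers are offset by the number of training environments.
test_envs = gym.vector.AsyncVectorEnv(
    [alr_envs.make_rank("alr_envs:HoleReacher-v0", 123 + n_train, i) for i in range(n_test)])
```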
@@ -3,7 +3,7 @@ from gym.vector.async_vector_env import AsyncVectorEnv
 import numpy as np
 from _collections import defaultdict

-from alr_envs.utils.make_env_helpers import make_env_rank
+from alr_envs.utils.make_env_helpers import make_rank


 def split_array(ary, size):
@@ -54,7 +54,7 @@ class AlrMpEnvSampler:

     def __init__(self, env_id, num_envs, seed=0, **env_kwargs):
         self.num_envs = num_envs
-        self.env = AsyncVectorEnv([make_env_rank(env_id, seed, i, **env_kwargs) for i in range(num_envs)])
+        self.env = AsyncVectorEnv([make_rank(env_id, seed, i, **env_kwargs) for i in range(num_envs)])

     def __call__(self, params):
         params = np.atleast_2d(params)
setup.py (1 line changed)
@@ -12,6 +12,7 @@ setup(
         'mp_env_api @ git+ssh://git@github.com/ALRhub/motion_primitive_env_api.git',
         'mujoco-py<2.1,>=2.0',
         'dm_control'
+        'metaworld @ git+https://github.com/rlworkgroup/metaworld.git@master#egg=metaworld'
     ],

     url='https://github.com/ALRhub/alr_envs/',
test/test_dmc_envs.py (new file, 127 lines)
@@ -0,0 +1,127 @@
import unittest

import gym
import numpy as np

from dm_control import suite, manipulation

from alr_envs import make

DMC_ENVS = [f'{env}-{task}' for env, task in suite.ALL_TASKS if env != "lqr"]
MANIPULATION_SPECS = [f'manipulation-{task}' for task in manipulation.ALL if task.endswith('_features')]
SEED = 1


class TestEnvironments(unittest.TestCase):

    def _run_env(self, env_id, iterations=None, seed=SEED, render=False):
        """
        Example for running a DMC based env in the step based setting.
        The env_id has to be specified as `domain_name-task_name` or
        for manipulation tasks as `manipulation-environment_name`

        Args:
            env_id: Either `domain_name-task_name` or `manipulation-environment_name`
            iterations: Number of rollout steps to run
            seed: random seeding
            render: Render the episode

        Returns:

        """
        env: gym.Env = make(env_id, seed=seed)
        rewards = []
        observations = []
        dones = []
        obs = env.reset()
        self._verify_observations(obs, env.observation_space, "reset()")

        length = env.spec.max_episode_steps
        if iterations is None:
            if length is None:
                iterations = 1
            else:
                iterations = length

        # number of samples (multiple environment steps)
        for i in range(iterations):
            observations.append(obs)

            ac = env.action_space.sample()
            # ac = np.random.uniform(env.action_space.low, env.action_space.high, env.action_space.shape)
            obs, reward, done, info = env.step(ac)

            self._verify_observations(obs, env.observation_space, "step()")
            self._verify_reward(reward)
            self._verify_done(done)

            rewards.append(reward)
            dones.append(done)

            if render:
                env.render("human")

            if done:
                obs = env.reset()

        assert done, "Done flag is not True after max episode length."
        observations.append(obs)
        env.close()
        del env
        return np.array(observations), np.array(rewards), np.array(dones)

    def _verify_observations(self, obs, observation_space, obs_type="reset()"):
        self.assertTrue(observation_space.contains(obs),
                        f"Observation {obs} received from {obs_type} "
                        f"not contained in observation space {observation_space}.")

    def _verify_reward(self, reward):
        self.assertIsInstance(reward, float, f"Returned {reward} as reward, expected float.")

    def _verify_done(self, done):
        self.assertIsInstance(done, bool, f"Returned {done} as done flag, expected bool.")

    def test_dmc_functionality(self):
        """Tests that environments run without errors using random actions."""
        for env_id in DMC_ENVS:
            with self.subTest(msg=env_id):
                self._run_env(env_id)

    def test_dmc_determinism(self):
        """Tests that identical seeds produce identical trajectories."""
        seed = 0
        # Iterate over two trajectories, which should have the same state and action sequence
        for env_id in DMC_ENVS:
            with self.subTest(msg=env_id):
                traj1 = self._run_env(env_id, seed=seed)
                traj2 = self._run_env(env_id, seed=seed)
                for i, time_step in enumerate(zip(*traj1, *traj2)):
                    obs1, rwd1, done1, obs2, rwd2, done2 = time_step
                    self.assertTrue(np.array_equal(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match.")
                    self.assertEqual(rwd1, rwd2, f"Rewards [{i}] {rwd1} and {rwd2} do not match.")
                    self.assertEqual(done1, done2, f"Dones [{i}] {done1} and {done2} do not match.")

    def test_manipulation_functionality(self):
        """Tests that environments run without errors using random actions."""
        for env_id in MANIPULATION_SPECS:
            with self.subTest(msg=env_id):
                self._run_env(env_id)

    def test_manipulation_determinism(self):
        """Tests that identical seeds produce identical trajectories."""
        seed = 0
        # Iterate over two trajectories, which should have the same state and action sequence
        for env_id in MANIPULATION_SPECS:
            with self.subTest(msg=env_id):
                traj1 = self._run_env(env_id, seed=seed)
                traj2 = self._run_env(env_id, seed=seed)
                for i, time_step in enumerate(zip(*traj1, *traj2)):
                    obs1, rwd1, done1, obs2, rwd2, done2 = time_step
                    self.assertTrue(np.array_equal(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match.")
                    self.assertEqual(rwd1, rwd2, f"Rewards [{i}] {rwd1} and {rwd2} do not match.")
                    self.assertEqual(done1, done2, f"Dones [{i}] {done1} and {done2} do not match.")


if __name__ == '__main__':
    unittest.main()
@@ -4,7 +4,7 @@ import gym
 import numpy as np

 import alr_envs  # noqa
-from alr_envs.utils.make_env_helpers import make_env
+from alr_envs.utils.make_env_helpers import make

 ALL_SPECS = list(spec for spec in gym.envs.registry.all() if "alr_envs" in spec.entry_point)
 SEED = 1
@@ -27,7 +27,7 @@ class TestEnvironments(unittest.TestCase):
         Returns:

         """
-        env: gym.Env = make_env(env_id, seed=seed)
+        env: gym.Env = make(env_id, seed=seed)
         rewards = []
         observations = []
         dones = []
@@ -62,6 +62,7 @@ class TestEnvironments(unittest.TestCase):
             if done:
                 obs = env.reset()

+        assert done, "Done flag is not True after max episode length."
         observations.append(obs)
         env.close()
         del env
@@ -81,7 +82,6 @@ class TestEnvironments(unittest.TestCase):
     def test_environment_functionality(self):
         """Tests that environments runs without errors using random actions."""
         for spec in ALL_SPECS:
-            # try:
             with self.subTest(msg=spec.id):
                 self._run_env(spec.id)

@@ -91,7 +91,6 @@ class TestEnvironments(unittest.TestCase):
         # Iterate over two trajectories, which should have the same state and action sequence
         for spec in ALL_SPECS:
             with self.subTest(msg=spec.id):
-                self._run_env(spec.id)
                 traj1 = self._run_env(spec.id, seed=seed)
                 traj2 = self._run_env(spec.id, seed=seed)
                 for i, time_step in enumerate(zip(*traj1, *traj2)):
test/test_metaworld_envs.py (new file, 107 lines)
@@ -0,0 +1,107 @@
import unittest

import gym
import numpy as np

from alr_envs import make
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE

ALL_ENVS = [env.split("-goal-observable")[0] for env, _ in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.items()]
SEED = 1


class TestEnvironments(unittest.TestCase):

    def _run_env(self, env_id, iterations=None, seed=SEED, render=False):
        """
        Example for running a MetaWorld based env in the step based setting.
        The env_id has to be specified as `task_name-v2`.

        Args:
            env_id: `task_name-v2`
            iterations: Number of rollout steps to run
            seed: random seeding
            render: Render the episode

        Returns:

        """
        env: gym.Env = make(env_id, seed=seed)
        rewards = []
        observations = []
        actions = []
        dones = []
        obs = env.reset()
        self._verify_observations(obs, env.observation_space, "reset()")

        length = env.max_path_length
        if iterations is None:
            if length is None:
                iterations = 1
            else:
                iterations = length

        # number of samples (multiple environment steps)
        for i in range(iterations):
            observations.append(obs)

            ac = env.action_space.sample()
            actions.append(ac)
            # ac = np.random.uniform(env.action_space.low, env.action_space.high, env.action_space.shape)
            obs, reward, done, info = env.step(ac)

            self._verify_observations(obs, env.observation_space, "step()")
            self._verify_reward(reward)
            self._verify_done(done)

            rewards.append(reward)
            dones.append(done)

            if render:
                env.render("human")

            if done:
                obs = env.reset()

        assert done, "Done flag is not True after max episode length."
        observations.append(obs)
        env.close()
        del env
        return np.array(observations), np.array(rewards), np.array(dones), np.array(actions)

    def _verify_observations(self, obs, observation_space, obs_type="reset()"):
        self.assertTrue(observation_space.contains(obs),
                        f"Observation {obs} received from {obs_type} "
                        f"not contained in observation space {observation_space}.")

    def _verify_reward(self, reward):
        self.assertIsInstance(reward, float, f"Returned {reward} as reward, expected float.")

    def _verify_done(self, done):
        self.assertIsInstance(done, bool, f"Returned {done} as done flag, expected bool.")

    def test_dmc_functionality(self):
        """Tests that environments run without errors using random actions."""
        for env_id in ALL_ENVS:
            with self.subTest(msg=env_id):
                self._run_env(env_id)

    def test_dmc_determinism(self):
        """Tests that identical seeds produce identical trajectories."""
        seed = 0
        # Iterate over two trajectories, which should have the same state and action sequence
        for env_id in ALL_ENVS:
            with self.subTest(msg=env_id):
                traj1 = self._run_env(env_id, seed=seed)
                traj2 = self._run_env(env_id, seed=seed)
                for i, time_step in enumerate(zip(*traj1, *traj2)):
                    obs1, rwd1, done1, ac1, obs2, rwd2, done2, ac2 = time_step
                    self.assertTrue(np.array_equal(ac1, ac2), f"Actions [{i}] delta {ac1 - ac2} is not zero.")
                    self.assertTrue(np.array_equal(obs1, obs2), f"Observations [{i}] delta {obs1 - obs2} is not zero.")
                    self.assertAlmostEqual(rwd1, rwd2, msg=f"Rewards [{i}] {rwd1} and {rwd2} do not match.")
                    self.assertEqual(done1, done2, f"Dones [{i}] {done1} and {done2} do not match.")


if __name__ == '__main__':
    unittest.main()
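A hedged way to run just these checks from Python (assuming the `test` directory is importable as a package):

```python
import unittest

from test.test_metaworld_envs import TestEnvironments

suite = unittest.TestLoader().loadTestsFromTestCase(TestEnvironments)
unittest.TextTestRunner(verbosity=2).run(suite)
```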