added all metaworld tasks as ProMP
This commit is contained in:
parent
cb603859d9
commit
c39877ece0
@ -316,7 +316,6 @@ register(
|
|||||||
"alpha_phase": 2,
|
"alpha_phase": 2,
|
||||||
"bandwidth_factor": 2,
|
"bandwidth_factor": 2,
|
||||||
"policy_type": "motor",
|
"policy_type": "motor",
|
||||||
"weights_scale": 50,
|
|
||||||
"goal_scale": 0.1,
|
"goal_scale": 0.1,
|
||||||
"policy_kwargs": {
|
"policy_kwargs": {
|
||||||
"p_gains": 50,
|
"p_gains": 50,
|
||||||
@ -340,7 +339,6 @@ register(
|
|||||||
"duration": 2,
|
"duration": 2,
|
||||||
"width": 0.025,
|
"width": 0.025,
|
||||||
"policy_type": "motor",
|
"policy_type": "motor",
|
||||||
"weights_scale": 1,
|
|
||||||
"zero_start": True,
|
"zero_start": True,
|
||||||
"policy_kwargs": {
|
"policy_kwargs": {
|
||||||
"p_gains": 50,
|
"p_gains": 50,
|
||||||
@ -828,6 +826,7 @@ register(
|
|||||||
"duration": 2,
|
"duration": 2,
|
||||||
"post_traj_time": 0,
|
"post_traj_time": 0,
|
||||||
"width": 0.02,
|
"width": 0.02,
|
||||||
|
"zero_start": True,
|
||||||
"policy_type": "motor",
|
"policy_type": "motor",
|
||||||
"policy_kwargs": {
|
"policy_kwargs": {
|
||||||
"p_gains": 1.,
|
"p_gains": 1.,
|
||||||
@ -849,6 +848,7 @@ register(
|
|||||||
"duration": 1,
|
"duration": 1,
|
||||||
"post_traj_time": 0,
|
"post_traj_time": 0,
|
||||||
"width": 0.02,
|
"width": 0.02,
|
||||||
|
"zero_start": True,
|
||||||
"policy_type": "motor",
|
"policy_type": "motor",
|
||||||
"policy_kwargs": {
|
"policy_kwargs": {
|
||||||
"p_gains": .6,
|
"p_gains": .6,
|
||||||
@ -870,6 +870,7 @@ register(
|
|||||||
"duration": 2,
|
"duration": 2,
|
||||||
"post_traj_time": 0,
|
"post_traj_time": 0,
|
||||||
"width": 0.02,
|
"width": 0.02,
|
||||||
|
"zero_start": True,
|
||||||
"policy_type": "position"
|
"policy_type": "position"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -887,6 +888,7 @@ register(
|
|||||||
"duration": 2,
|
"duration": 2,
|
||||||
"post_traj_time": 0,
|
"post_traj_time": 0,
|
||||||
"width": 0.02,
|
"width": 0.02,
|
||||||
|
"zero_start": True,
|
||||||
"policy_type": "position"
|
"policy_type": "position"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -904,6 +906,7 @@ register(
|
|||||||
"duration": 2,
|
"duration": 2,
|
||||||
"post_traj_time": 0,
|
"post_traj_time": 0,
|
||||||
"width": 0.02,
|
"width": 0.02,
|
||||||
|
"zero_start": True,
|
||||||
"policy_type": "position"
|
"policy_type": "position"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -921,25 +924,109 @@ register(
|
|||||||
"duration": 2,
|
"duration": 2,
|
||||||
"post_traj_time": 0,
|
"post_traj_time": 0,
|
||||||
"width": 0.02,
|
"width": 0.02,
|
||||||
|
"zero_start": True,
|
||||||
"policy_type": "position"
|
"policy_type": "position"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# MetaWorld
|
||||||
|
|
||||||
|
goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
|
||||||
|
]
|
||||||
|
for env_id in goal_change_envs:
|
||||||
|
env_id_split = env_id.split("-")
|
||||||
|
name = "".join([s.capitalize() for s in env_id_split[:-1]])
|
||||||
register(
|
register(
|
||||||
id='ButtonPressDetPMP-v2',
|
id=f'{name}DetPMP-{env_id_split[-1]}',
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": "button-press-v2",
|
"name": env_id,
|
||||||
"wrappers": [meta.button_press.MPWrapper],
|
"wrappers": [meta.goal_change.MPWrapper],
|
||||||
"mp_kwargs": {
|
"mp_kwargs": {
|
||||||
"num_dof": 4,
|
"num_dof": 4,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 6.25,
|
"duration": 6.25,
|
||||||
"post_traj_time": 0,
|
"post_traj_time": 0,
|
||||||
"width": 0.025,
|
"width": 0.025,
|
||||||
"policy_type": "position"
|
"zero_start": True,
|
||||||
|
"policy_type": "metaworld",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
object_change_envs = ["bin-picking-v2", "hammer-v2", "sweep-into-v2"]
|
||||||
|
for env_id in object_change_envs:
|
||||||
|
env_id_split = env_id.split("-")
|
||||||
|
name = "".join([s.capitalize() for s in env_id_split[:-1]])
|
||||||
|
register(
|
||||||
|
id=f'{name}DetPMP-{env_id_split[-1]}',
|
||||||
|
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||||
|
kwargs={
|
||||||
|
"name": env_id,
|
||||||
|
"wrappers": [meta.object_change.MPWrapper],
|
||||||
|
"mp_kwargs": {
|
||||||
|
"num_dof": 4,
|
||||||
|
"num_basis": 5,
|
||||||
|
"duration": 6.25,
|
||||||
|
"post_traj_time": 0,
|
||||||
|
"width": 0.025,
|
||||||
|
"zero_start": True,
|
||||||
|
"policy_type": "metaworld",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press-wall-v2", "button-press-topdown-v2",
|
||||||
|
"button-press-topdown-wall-v2", "coffee-button-v2", "coffee-pull-v2",
|
||||||
|
"coffee-push-v2", "dial-turn-v2", "disassemble-v2", "door-close-v2",
|
||||||
|
"door-lock-v2", "door-open-v2", "door-unlock-v2", "hand-insert-v2",
|
||||||
|
"drawer-close-v2", "drawer-open-v2", "faucet-open-v2", "faucet-close-v2",
|
||||||
|
"handle-press-side-v2", "handle-press-v2", "handle-pull-side-v2",
|
||||||
|
"handle-pull-v2", "lever-pull-v2", "peg-insert-side-v2", "pick-place-wall-v2",
|
||||||
|
"reach-v2", "push-back-v2", "push-v2", "pick-place-v2", "peg-unplug-side-v2",
|
||||||
|
"soccer-v2", "stick-push-v2", "stick-pull-v2", "push-wall-v2", "reach-wall-v2",
|
||||||
|
"shelf-place-v2", "sweep-v2", "window-open-v2", "window-close-v2"
|
||||||
|
]
|
||||||
|
for env_id in goal_and_object_change_envs:
|
||||||
|
env_id_split = env_id.split("-")
|
||||||
|
name = "".join([s.capitalize() for s in env_id_split[:-1]])
|
||||||
|
register(
|
||||||
|
id=f'{name}DetPMP-{env_id_split[-1]}',
|
||||||
|
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||||
|
kwargs={
|
||||||
|
"name": env_id,
|
||||||
|
"wrappers": [meta.goal_and_object_change.MPWrapper],
|
||||||
|
"mp_kwargs": {
|
||||||
|
"num_dof": 4,
|
||||||
|
"num_basis": 5,
|
||||||
|
"duration": 6.25,
|
||||||
|
"post_traj_time": 0,
|
||||||
|
"width": 0.025,
|
||||||
|
"zero_start": True,
|
||||||
|
"policy_type": "metaworld",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
goal_and_endeffector_change_envs = ["basketball-v2"]
|
||||||
|
for env_id in goal_and_endeffector_change_envs:
|
||||||
|
env_id_split = env_id.split("-")
|
||||||
|
name = "".join([s.capitalize() for s in env_id_split[:-1]])
|
||||||
|
register(
|
||||||
|
id=f'{name}DetPMP-{env_id_split[-1]}',
|
||||||
|
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||||
|
kwargs={
|
||||||
|
"name": env_id,
|
||||||
|
"wrappers": [meta.goal_and_endeffector_change.MPWrapper],
|
||||||
|
"mp_kwargs": {
|
||||||
|
"num_dof": 4,
|
||||||
|
"num_basis": 5,
|
||||||
|
"duration": 6.25,
|
||||||
|
"post_traj_time": 0,
|
||||||
|
"width": 0.025,
|
||||||
|
"zero_start": True,
|
||||||
|
"policy_type": "metaworld",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@ -80,20 +80,21 @@ def example_async(env_id="alr_envs:HoleReacher-v0", n_cpu=4, seed=int('533D', 16
|
|||||||
rewards[done] = 0
|
rewards[done] = 0
|
||||||
|
|
||||||
# do not return values above threshold
|
# do not return values above threshold
|
||||||
return *map(lambda v: np.stack(v)[:n_samples], buffer.values()),
|
return (*map(lambda v: np.stack(v)[:n_samples], buffer.values()),)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
render = False
|
render = True
|
||||||
|
|
||||||
# Basic gym task
|
# Basic gym task
|
||||||
example_general("Pendulum-v0", seed=10, iterations=200, render=render)
|
example_general("Pendulum-v0", seed=10, iterations=200, render=render)
|
||||||
#
|
|
||||||
# # Basis task from framework
|
# # Basis task from framework
|
||||||
example_general("alr_envs:HoleReacher-v0", seed=10, iterations=200, render=render)
|
example_general("alr_envs:HoleReacher-v0", seed=10, iterations=200, render=render)
|
||||||
#
|
|
||||||
# # OpenAI Mujoco task
|
# # OpenAI Mujoco task
|
||||||
example_general("HalfCheetah-v2", seed=10, render=render)
|
example_general("HalfCheetah-v2", seed=10, render=render)
|
||||||
#
|
|
||||||
# # Mujoco task from framework
|
# # Mujoco task from framework
|
||||||
example_general("alr_envs:ALRReacher-v0", seed=10, iterations=200, render=render)
|
example_general("alr_envs:ALRReacher-v0", seed=10, iterations=200, render=render)
|
||||||
|
|
||||||
|
@ -1,30 +1,30 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
from alr_envs import dmc
|
from alr_envs import dmc, meta
|
||||||
from alr_envs.utils.make_env_helpers import make_detpmp_env
|
from alr_envs.utils.make_env_helpers import make_detpmp_env
|
||||||
|
|
||||||
# This might work for some environments, however, please verify either way the correct trajectory information
|
# This might work for some environments, however, please verify either way the correct trajectory information
|
||||||
# for your environment are extracted below
|
# for your environment are extracted below
|
||||||
SEED = 10
|
SEED = 10
|
||||||
env_id = "cartpole-swingup"
|
env_id = "ball_in_cup-catch"
|
||||||
wrappers = [dmc.suite.cartpole.MPWrapper]
|
wrappers = [dmc.ball_in_cup.MPWrapper]
|
||||||
|
|
||||||
mp_kwargs = {
|
mp_kwargs = {
|
||||||
"num_dof": 1,
|
"num_dof": 2,
|
||||||
"num_basis": 5,
|
"num_basis": 10,
|
||||||
"duration": 2,
|
"duration": 2,
|
||||||
"width": 0.025,
|
"width": 0.025,
|
||||||
"policy_type": "motor",
|
"policy_type": "motor",
|
||||||
"weights_scale": 0.2,
|
"weights_scale": 1,
|
||||||
"zero_start": True,
|
"zero_start": True,
|
||||||
"policy_kwargs": {
|
"policy_kwargs": {
|
||||||
"p_gains": 10,
|
"p_gains": 1,
|
||||||
"d_gains": 10 # a good starting point is the sqrt of p_gains
|
"d_gains": 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
kwargs = dict(time_limit=2, episode_length=200)
|
kwargs = dict(time_limit=2, episode_length=100)
|
||||||
|
|
||||||
env = make_detpmp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs,
|
env = make_detpmp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
@ -35,7 +35,6 @@ pos, vel = env.mp_rollout(env.action_space.sample())
|
|||||||
|
|
||||||
base_shape = env.full_action_space.shape
|
base_shape = env.full_action_space.shape
|
||||||
actual_pos = np.zeros((len(pos), *base_shape))
|
actual_pos = np.zeros((len(pos), *base_shape))
|
||||||
actual_pos_ball = np.zeros((len(pos), *base_shape))
|
|
||||||
actual_vel = np.zeros((len(pos), *base_shape))
|
actual_vel = np.zeros((len(pos), *base_shape))
|
||||||
act = np.zeros((len(pos), *base_shape))
|
act = np.zeros((len(pos), *base_shape))
|
||||||
|
|
||||||
@ -46,7 +45,6 @@ for t, pos_vel in enumerate(zip(pos, vel)):
|
|||||||
act[t, :] = actions
|
act[t, :] = actions
|
||||||
# TODO verify for your environment
|
# TODO verify for your environment
|
||||||
actual_pos[t, :] = env.current_pos
|
actual_pos[t, :] = env.current_pos
|
||||||
# actual_pos_ball[t, :] = env.physics.data.qpos[2:]
|
|
||||||
actual_vel[t, :] = env.current_vel
|
actual_vel[t, :] = env.current_vel
|
||||||
|
|
||||||
plt.figure(figsize=(15, 5))
|
plt.figure(figsize=(15, 5))
|
||||||
|
@ -1 +1 @@
|
|||||||
from alr_envs.meta import button_press
|
from alr_envs.meta import goal_and_object_change, goal_and_endeffector_change, goal_change, object_change
|
||||||
|
68
alr_envs/meta/goal_and_endeffector_change.py
Normal file
68
alr_envs/meta/goal_and_endeffector_change.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
from typing import Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mp_env_api import MPEnvWrapper
|
||||||
|
|
||||||
|
|
||||||
|
class MPWrapper(MPEnvWrapper):
|
||||||
|
"""
|
||||||
|
This Wrapper is for environments where the goal and the end-effector position change in the beginning
|
||||||
|
and no main or secondary objects are altered at the start of an episode.
|
||||||
|
You can verify this by executing the code below for your environment id and check if the output is non-zero
|
||||||
|
at the same indices.
|
||||||
|
```python
|
||||||
|
import alr_envs
|
||||||
|
env = alr_envs.make(env_id, 1)
|
||||||
|
print(env.reset() - env.reset())
|
||||||
|
array([ !=0 , !=0 , !=0 , 0. , 0.,
|
||||||
|
0. , 0. , 0. , 0. , 0. ,
|
||||||
|
0. , 0. , 0. , 0. , 0. ,
|
||||||
|
0. , 0. , 0. , !=0 , !=0 ,
|
||||||
|
!=0 , 0. , 0. , 0. , 0. ,
|
||||||
|
0. , 0. , 0. , 0. , 0. ,
|
||||||
|
0. , 0. , 0. , 0. , 0. ,
|
||||||
|
0. , !=0 , !=0 , !=0])
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def active_obs(self):
|
||||||
|
# This structure is the same for all metaworld environments.
|
||||||
|
# Only the observations which change could differ
|
||||||
|
return np.hstack([
|
||||||
|
# Current observation
|
||||||
|
[True] * 3, # end-effector position
|
||||||
|
[False] * 1, # normalized gripper open distance
|
||||||
|
[False] * 3, # main object position
|
||||||
|
[False] * 4, # main object quaternion
|
||||||
|
[False] * 3, # secondary object position
|
||||||
|
[False] * 4, # secondary object quaternion
|
||||||
|
# Previous observation
|
||||||
|
# TODO: Include previous values? According to their source they might be wrong for the first iteration.
|
||||||
|
[False] * 3, # previous end-effector position
|
||||||
|
[False] * 1, # previous normalized gripper open distance
|
||||||
|
[False] * 3, # previous main object position
|
||||||
|
[False] * 4, # previous main object quaternion
|
||||||
|
[False] * 3, # previous second object position
|
||||||
|
[False] * 4, # previous second object quaternion
|
||||||
|
# Goal
|
||||||
|
[True] * 3, # goal position
|
||||||
|
])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
|
r_close = self.env.data.get_joint_qpos("r_close")
|
||||||
|
return np.hstack([self.env.data.mocap_pos.flatten(), r_close])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
|
raise NotImplementedError("Velocity cannot be retrieved.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
|
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dt(self) -> Union[float, int]:
|
||||||
|
return self.env.dt
|
@ -6,6 +6,25 @@ from mp_env_api import MPEnvWrapper
|
|||||||
|
|
||||||
|
|
||||||
class MPWrapper(MPEnvWrapper):
|
class MPWrapper(MPEnvWrapper):
|
||||||
|
"""
|
||||||
|
This Wrapper is for environments where the goal and the main object position change in the beginning
|
||||||
|
and no secondary objects or end effectors are altered at the start of an episode.
|
||||||
|
You can verify this by executing the code below for your environment id and check if the output is non-zero
|
||||||
|
at the same indices.
|
||||||
|
```python
|
||||||
|
import alr_envs
|
||||||
|
env = alr_envs.make(env_id, 1)
|
||||||
|
print(env.reset() - env.reset())
|
||||||
|
array([ 0. , 0. , 0. , 0. , !=0,
|
||||||
|
!=0 , !=0 , 0. , 0. , 0. ,
|
||||||
|
0. , 0. , 0. , 0. , 0. ,
|
||||||
|
0. , 0. , 0. , 0. , 0. ,
|
||||||
|
0. , 0. , !=0 , !=0 , !=0 ,
|
||||||
|
0. , 0. , 0. , 0. , 0. ,
|
||||||
|
0. , 0. , 0. , 0. , 0. ,
|
||||||
|
0. , !=0 , !=0 , !=0])
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def active_obs(self):
|
def active_obs(self):
|
68
alr_envs/meta/goal_change.py
Normal file
68
alr_envs/meta/goal_change.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
from typing import Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mp_env_api import MPEnvWrapper
|
||||||
|
|
||||||
|
|
||||||
|
class MPWrapper(MPEnvWrapper):
|
||||||
|
"""
|
||||||
|
This Wrapper is for environments where merely the goal changes in the beginning
|
||||||
|
and no secondary objects or end effectors are altered at the start of an episode.
|
||||||
|
You can verify this by executing the code below for your environment id and check if the output is non-zero
|
||||||
|
at the same indices.
|
||||||
|
```python
|
||||||
|
import alr_envs
|
||||||
|
env = alr_envs.make(env_id, 1)
|
||||||
|
print(env.reset() - env.reset())
|
||||||
|
array([ 0. , 0. , 0. , 0. , 0,
|
||||||
|
0 , 0 , 0. , 0. , 0. ,
|
||||||
|
0. , 0. , 0. , 0. , 0. ,
|
||||||
|
0. , 0. , 0. , 0. , 0. ,
|
||||||
|
0. , 0. , 0 , 0 , 0 ,
|
||||||
|
0. , 0. , 0. , 0. , 0. ,
|
||||||
|
0. , 0. , 0. , 0. , 0. ,
|
||||||
|
0. , !=0 , !=0 , !=0])
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def active_obs(self):
|
||||||
|
# This structure is the same for all metaworld environments.
|
||||||
|
# Only the observations which change could differ
|
||||||
|
return np.hstack([
|
||||||
|
# Current observation
|
||||||
|
[False] * 3, # end-effector position
|
||||||
|
[False] * 1, # normalized gripper open distance
|
||||||
|
[False] * 3, # main object position
|
||||||
|
[False] * 4, # main object quaternion
|
||||||
|
[False] * 3, # secondary object position
|
||||||
|
[False] * 4, # secondary object quaternion
|
||||||
|
# Previous observation
|
||||||
|
# TODO: Include previous values? According to their source they might be wrong for the first iteration.
|
||||||
|
[False] * 3, # previous end-effector position
|
||||||
|
[False] * 1, # previous normalized gripper open distance
|
||||||
|
[False] * 3, # previous main object position
|
||||||
|
[False] * 4, # previous main object quaternion
|
||||||
|
[False] * 3, # previous second object position
|
||||||
|
[False] * 4, # previous second object quaternion
|
||||||
|
# Goal
|
||||||
|
[True] * 3, # goal position
|
||||||
|
])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
|
r_close = self.env.data.get_joint_qpos("r_close")
|
||||||
|
return np.hstack([self.env.data.mocap_pos.flatten(), r_close])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
|
raise NotImplementedError("Velocity cannot be retrieved.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
|
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dt(self) -> Union[float, int]:
|
||||||
|
return self.env.dt
|
68
alr_envs/meta/object_change.py
Normal file
68
alr_envs/meta/object_change.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
from typing import Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mp_env_api import MPEnvWrapper
|
||||||
|
|
||||||
|
|
||||||
|
class MPWrapper(MPEnvWrapper):
|
||||||
|
"""
|
||||||
|
This Wrapper is for environments where merely the main object position changes in the beginning
|
||||||
|
and neither the goal, secondary objects nor end effectors are altered at the start of an episode.
|
||||||
|
You can verify this by executing the code below for your environment id and check if the output is non-zero
|
||||||
|
at the same indices.
|
||||||
|
```python
|
||||||
|
import alr_envs
|
||||||
|
env = alr_envs.make(env_id, 1)
|
||||||
|
print(env.reset() - env.reset())
|
||||||
|
array([ 0. , 0. , 0. , 0. , !=0 ,
|
||||||
|
!=0 , !=0 , 0. , 0. , 0. ,
|
||||||
|
0. , 0. , 0. , 0. , 0. ,
|
||||||
|
0. , 0. , 0. , 0. , 0. ,
|
||||||
|
0. , 0. , 0 , 0 , 0 ,
|
||||||
|
0. , 0. , 0. , 0. , 0. ,
|
||||||
|
0. , 0. , 0. , 0. , 0. ,
|
||||||
|
0. , 0. , 0. , 0.])
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def active_obs(self):
|
||||||
|
# This structure is the same for all metaworld environments.
|
||||||
|
# Only the observations which change could differ
|
||||||
|
return np.hstack([
|
||||||
|
# Current observation
|
||||||
|
[False] * 3, # end-effector position
|
||||||
|
[False] * 1, # normalized gripper open distance
|
||||||
|
[False] * 3, # main object position
|
||||||
|
[False] * 4, # main object quaternion
|
||||||
|
[False] * 3, # secondary object position
|
||||||
|
[False] * 4, # secondary object quaternion
|
||||||
|
# Previous observation
|
||||||
|
# TODO: Include previous values? According to their source they might be wrong for the first iteration.
|
||||||
|
[False] * 3, # previous end-effector position
|
||||||
|
[False] * 1, # previous normalized gripper open distance
|
||||||
|
[False] * 3, # previous main object position
|
||||||
|
[False] * 4, # previous main object quaternion
|
||||||
|
[False] * 3, # previous second object position
|
||||||
|
[False] * 4, # previous second object quaternion
|
||||||
|
# Goal
|
||||||
|
[True] * 3, # goal position
|
||||||
|
])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
|
r_close = self.env.data.get_joint_qpos("r_close")
|
||||||
|
return np.hstack([self.env.data.mocap_pos.flatten(), r_close])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
|
raise NotImplementedError("Velocity cannot be retrieved.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
|
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dt(self) -> Union[float, int]:
|
||||||
|
return self.env.dt
|
Loading…
Reference in New Issue
Block a user