from typing import Tuple, Union

import numpy as np

from mp_env_api import MPEnvWrapper


class MPWrapper(MPEnvWrapper):

    @property
    def active_obs(self):
        # This structure is the same for all metaworld environments.
        # Only the observations that change between tasks may differ.
        return np.hstack([
            # Current observation
            [False] * 3,  # end-effector position
            [False] * 1,  # normalized gripper open distance
            [True] * 3,   # main object position
            [False] * 4,  # main object quaternion
            [False] * 3,  # secondary object position
            [False] * 4,  # secondary object quaternion
            # Previous observation
            # TODO: Include previous values? According to the metaworld source they might be wrong for the first iteration.
            [False] * 3,  # previous end-effector position
            [False] * 1,  # previous normalized gripper open distance
            [False] * 3,  # previous main object position
            [False] * 4,  # previous main object quaternion
            [False] * 3,  # previous secondary object position
            [False] * 4,  # previous secondary object quaternion
            # Goal
            [True] * 3,   # goal position
        ])

    @property
    def current_pos(self) -> Union[float, int, np.ndarray]:
        # Full joint position vector of the simulation.
        return self.env.physics.named.data.qpos[:]

    @property
    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
        # Full joint velocity vector of the simulation.
        return self.env.physics.named.data.qvel[:]

    @property
    def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
        raise ValueError("Goal position is not available and has to be learnt based on the environment.")

    @property
    def dt(self) -> Union[float, int]:
        # Control timestep of the underlying environment.
        return self.env.dt