2022-06-29 09:37:18 +02:00
|
|
|
from typing import Union, Tuple
|
2022-07-01 11:42:42 +02:00
|
|
|
from mp_pytorch.mp.mp_interfaces import MPInterface
|
|
|
|
from abc import abstractmethod
|
2022-06-29 09:37:18 +02:00
|
|
|
|
|
|
|
import gym
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
class RawInterfaceWrapper(gym.Wrapper):
    """
    Gym wrapper that exposes the raw interface an environment has to provide for motion-primitive control.

    Subclasses must implement ``current_pos`` and ``current_vel``. ``context_mask``, ``dt``,
    ``do_replanning`` and ``_episode_callback`` ship with defaults that only need to be
    overridden when an environment requires different behavior.
    """

    @property
    def context_mask(self) -> np.ndarray:
        """
        This function defines the contexts. The contexts are defined as specific observations.

        Returns:
            bool array representing the indices of the observations. The default marks every one of the
            ``self.env.observation_space.shape[0]`` observation entries as part of the context.
        """
        # Deliberately not abstract: using every observation entry is a valid default, so subclasses only
        # override this when they want to exclude parts of the observation from the context.
        return np.ones(self.env.observation_space.shape[0], dtype=bool)

    @property
    @abstractmethod
    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
        """
        Returns the current position of the action/control dimension.
        The dimensionality has to match the action/control dimension.
        This is not required when exclusively using velocity control,
        it should, however, be implemented regardless.
        E.g. The joint positions that are directly or indirectly controlled by the action.
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
        """
        Returns the current velocity of the action/control dimension.
        The dimensionality has to match the action/control dimension.
        This is not required when exclusively using position control,
        it should, however, be implemented regardless.
        E.g. The joint velocities that are directly or indirectly controlled by the action.
        """
        raise NotImplementedError()

    @property
    def dt(self) -> float:
        """
        Control time step of the environment (the inverse of the control frequency).

        Returns:
            float: the ``dt`` attribute of the wrapped environment; units are whatever that environment uses.
        """
        return self.env.dt

    def do_replanning(self, pos, vel, s, a, t):
        """
        Decide whether a new trajectory should be generated at the current step.

        Hook for subclasses, e.g. to replan every fixed number of steps. The default never replans.

        Args:
            pos: presumably the current position of the controlled dimensions — confirm against caller.
            vel: presumably the current velocity of the controlled dimensions — confirm against caller.
            s: presumably the current observation — confirm against caller.
            a: presumably the current action — confirm against caller.
            t: presumably the current step counter — confirm against caller.

        Returns:
            bool: always ``False`` in this default implementation.
        """
        return False

    def _episode_callback(self, action: np.ndarray, traj_gen: MPInterface) -> Tuple[
            np.ndarray, Union[np.ndarray, None]]:
        """
        Used to extract the parameters for the motion primitive and other parameters from an action array which
        might include other actions like ball releasing time for the beer pong environment.
        This only needs to be overwritten if the action space is modified.

        Args:
            action: a vector instance of the whole action space, includes traj_gen parameters and additional
                parameters if specified, else only traj_gen parameters
            traj_gen: the trajectory generator the parameters are meant for; unused by this default implementation.

        Returns:
            Tuple: (mp_arguments, other_arguments). The default returns ``action`` unchanged as the motion primitive
            arguments and ``None`` for the other arguments.
        """
        return action, None
|