updates
This commit is contained in:
parent
7b2451d317
commit
5a547d85f9
@ -23,7 +23,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
replanning_schedule: Optional[
|
replanning_schedule: Optional[
|
||||||
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
|
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
|
||||||
reward_aggregation: Callable[[np.ndarray], float] = np.sum,
|
reward_aggregation: Callable[[np.ndarray], float] = np.sum,
|
||||||
max_planning_times: int = 1,
|
max_planning_times = None,
|
||||||
desired_conditioning: bool = False
|
desired_conditioning: bool = False
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -163,8 +163,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
# action = np.concatenate((basis_weights, goal_weights), axis=1).flatten()
|
# action = np.concatenate((basis_weights, goal_weights), axis=1).flatten()
|
||||||
|
|
||||||
# TODO remove this part, right now only needed for beer pong
|
# TODO remove this part, right now only needed for beer pong
|
||||||
mp_params, env_spec_params = self.env.episode_callback(action, self.traj_gen)
|
# mp_params, env_spec_params, proceed = self.env.episode_callback(action, self.traj_gen)
|
||||||
position, velocity = self.get_trajectory(mp_params)
|
position, velocity = self.get_trajectory(action)
|
||||||
|
traj_is_valid = self.env.episode_callback(action, position, velocity)
|
||||||
|
|
||||||
trajectory_length = len(position)
|
trajectory_length = len(position)
|
||||||
rewards = np.zeros(shape=(trajectory_length,))
|
rewards = np.zeros(shape=(trajectory_length,))
|
||||||
@ -176,6 +177,13 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
infos = dict()
|
infos = dict()
|
||||||
done = False
|
done = False
|
||||||
|
|
||||||
|
if self.verbose >= 2:
|
||||||
|
desired_pos_traj = []
|
||||||
|
desired_vel_traj = []
|
||||||
|
pos_traj = []
|
||||||
|
vel_traj = []
|
||||||
|
|
||||||
|
if traj_is_valid:
|
||||||
self.plan_counts += 1
|
self.plan_counts += 1
|
||||||
for t, (pos, vel) in enumerate(zip(position, velocity)):
|
for t, (pos, vel) in enumerate(zip(position, velocity)):
|
||||||
step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
|
step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
|
||||||
@ -192,14 +200,20 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
elems[t] = v
|
elems[t] = v
|
||||||
infos[k] = elems
|
infos[k] = elems
|
||||||
|
|
||||||
|
if self.verbose >= 2:
|
||||||
|
desired_pos_traj.append(pos)
|
||||||
|
desired_vel_traj.append(vel)
|
||||||
|
pos_traj.append(self.current_pos)
|
||||||
|
vel_traj.append(self.current_vel)
|
||||||
|
|
||||||
if self.render_kwargs:
|
if self.render_kwargs:
|
||||||
self.env.render(**self.render_kwargs)
|
self.env.render(**self.render_kwargs)
|
||||||
|
|
||||||
if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
|
if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
|
||||||
t + 1 + self.current_traj_steps):
|
t + 1 + self.current_traj_steps):
|
||||||
|
|
||||||
if self.max_planning_times is not None and self.plan_counts >= self.max_planning_times:
|
# if self.max_planning_times is not None and self.plan_counts >= self.max_planning_times:
|
||||||
continue
|
# continue
|
||||||
|
|
||||||
self.condition_pos = pos if self.desired_conditioning else self.current_pos
|
self.condition_pos = pos if self.desired_conditioning else self.current_pos
|
||||||
self.condition_vel = vel if self.desired_conditioning else self.current_vel
|
self.condition_vel = vel if self.desired_conditioning else self.current_vel
|
||||||
@ -217,11 +231,17 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
infos['step_actions'] = actions[:t + 1]
|
infos['step_actions'] = actions[:t + 1]
|
||||||
infos['step_observations'] = observations[:t + 1]
|
infos['step_observations'] = observations[:t + 1]
|
||||||
infos['step_rewards'] = rewards[:t + 1]
|
infos['step_rewards'] = rewards[:t + 1]
|
||||||
|
infos['desired_pos_traj'] = np.array(desired_pos_traj)
|
||||||
|
infos['desired_vel_traj'] = np.array(desired_vel_traj)
|
||||||
|
infos['pos_traj'] = np.array(pos_traj)
|
||||||
|
infos['vel_traj'] = np.array(vel_traj)
|
||||||
|
|
||||||
infos['trajectory_length'] = t + 1
|
infos['trajectory_length'] = t + 1
|
||||||
trajectory_return = self.reward_aggregation(rewards[:t + 1])
|
trajectory_return = self.reward_aggregation(rewards[:t + 1])
|
||||||
return self.observation(obs), trajectory_return, done, infos
|
return self.observation(obs), trajectory_return, done, infos
|
||||||
|
else:
|
||||||
|
obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity)
|
||||||
|
return self.observation(obs), trajectory_return, done, infos
|
||||||
def render(self, **kwargs):
|
def render(self, **kwargs):
|
||||||
"""Only set render options here, such that they can be used during the rollout.
|
"""Only set render options here, such that they can be used during the rollout.
|
||||||
This only needs to be called once"""
|
This only needs to be called once"""
|
||||||
|
@ -52,8 +52,7 @@ class RawInterfaceWrapper(gym.Wrapper):
|
|||||||
"""
|
"""
|
||||||
return self.env.dt
|
return self.env.dt
|
||||||
|
|
||||||
def episode_callback(self, action: np.ndarray, traj_gen: MPInterface) -> Tuple[
|
def episode_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.array) -> Tuple[bool]:
|
||||||
np.ndarray, Union[np.ndarray, None]]:
|
|
||||||
"""
|
"""
|
||||||
Used to extract the parameters for the movement primitive and other parameters from an action array which might
|
Used to extract the parameters for the movement primitive and other parameters from an action array which might
|
||||||
include other actions like ball releasing time for the beer pong environment.
|
include other actions like ball releasing time for the beer pong environment.
|
||||||
@ -65,4 +64,11 @@ class RawInterfaceWrapper(gym.Wrapper):
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple: mp_arguments and other arguments
|
Tuple: mp_arguments and other arguments
|
||||||
"""
|
"""
|
||||||
return action, None
|
return True
|
||||||
|
|
||||||
|
def invalid_traj_callback(
        self,
        action: np.ndarray,
        pos_traj: np.ndarray,
        vel_traj: np.ndarray,
) -> Tuple[np.ndarray, float, bool, dict]:
    """
    Produce a substitute step result for a desired trajectory that was judged invalid.

    This default implementation does not look at any of its arguments.

    Args:
        action: action the trajectory was generated from (unused here)
        pos_traj: desired position trajectory (unused here)
        vel_traj: desired velocity trajectory (unused here)

    Returns:
        Tuple: a length-1 zero observation, a reward of 0, done=True and an empty info dict
    """
    placeholder_obs = np.zeros(1)
    reward, done, info = 0, True, {}
    return placeholder_obs, reward, done, info
|
@ -28,7 +28,9 @@ DEFAULT_BB_DICT_ProMP = {
|
|||||||
'trajectory_generator_type': 'promp'
|
'trajectory_generator_type': 'promp'
|
||||||
},
|
},
|
||||||
"phase_generator_kwargs": {
|
"phase_generator_kwargs": {
|
||||||
'phase_generator_type': 'linear'
|
'phase_generator_type': 'linear',
|
||||||
|
'learn_tau': False,
|
||||||
|
'learn_delay': False,
|
||||||
},
|
},
|
||||||
"controller_kwargs": {
|
"controller_kwargs": {
|
||||||
'controller_type': 'motor',
|
'controller_type': 'motor',
|
||||||
@ -40,6 +42,8 @@ DEFAULT_BB_DICT_ProMP = {
|
|||||||
'num_basis': 5,
|
'num_basis': 5,
|
||||||
'num_basis_zero_start': 1,
|
'num_basis_zero_start': 1,
|
||||||
'basis_bandwidth_factor': 3.0,
|
'basis_bandwidth_factor': 3.0,
|
||||||
|
},
|
||||||
|
"black_box_kwargs": {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -245,6 +249,18 @@ register(
|
|||||||
max_episode_steps=FIXED_RELEASE_STEP,
|
max_episode_steps=FIXED_RELEASE_STEP,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Table Tennis environments
# One step-based environment is registered per context dimensionality.
# NOTE(review): max_episode_steps is 350 here, while table_tennis_env.py defines
# MAX_EPISODE_STEPS_TABLE_TENNIS = 250 - confirm which limit is intended.
for ctxt_dim in (2, 4):
    register(
        id='TableTennis{}D-v0'.format(ctxt_dim),
        entry_point='fancy_gym.envs.mujoco:TableTennisEnv',
        max_episode_steps=350,
        kwargs=dict(ctxt_dim=ctxt_dim, frame_skip=4),
    )
|
||||||
|
|
||||||
# movement Primitive Environments
|
# movement Primitive Environments
|
||||||
|
|
||||||
## Simple Reacher
|
## Simple Reacher
|
||||||
@ -515,6 +531,29 @@ for _v in _versions:
|
|||||||
kwargs=kwargs_dict_box_pushing_prodmp
|
kwargs=kwargs_dict_box_pushing_prodmp
|
||||||
)
|
)
|
||||||
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||||
|
|
||||||
|
## Table Tennis
_versions = ['TableTennis2D-v0', 'TableTennis4D-v0']
for _v in _versions:
    _name = _v.split("-")
    _env_id = f'{_name[0]}ProMP-{_name[1]}'

    # Work on a private copy so the shared ProMP defaults are never modified.
    kwargs_dict_tt_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
    kwargs_dict_tt_promp['wrappers'].append(mujoco.table_tennis.MPWrapper)
    kwargs_dict_tt_promp['name'] = _v

    # Tracking-controller gains, one entry per joint (7 values each).
    kwargs_dict_tt_promp['controller_kwargs'].update(
        p_gains=0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0]),
        d_gains=0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1]),
    )
    # tau and delay become part of the learned parameters for this task.
    kwargs_dict_tt_promp['phase_generator_kwargs'].update(learn_tau=True, learn_delay=True)
    kwargs_dict_tt_promp['basis_generator_kwargs'].update(num_basis=3, num_basis_zero_start=2)
    kwargs_dict_tt_promp['black_box_kwargs'].update(duration=2., verbose=2)

    register(
        id=_env_id,
        entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
        kwargs=kwargs_dict_tt_promp,
    )
    ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
#
|
#
|
||||||
# ## Walker2DJump
|
# ## Walker2DJump
|
||||||
# _versions = ['Walker2DJump-v0']
|
# _versions = ['Walker2DJump-v0']
|
||||||
|
@ -8,3 +8,4 @@ from .hopper_throw.hopper_throw_in_basket import HopperThrowInBasketEnv
|
|||||||
from .reacher.reacher import ReacherEnv
|
from .reacher.reacher import ReacherEnv
|
||||||
from .walker_2d_jump.walker_2d_jump import Walker2dJumpEnv
|
from .walker_2d_jump.walker_2d_jump import Walker2dJumpEnv
|
||||||
from .box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, BoxPushingTemporalSpatialSparse
|
from .box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, BoxPushingTemporalSpatialSparse
|
||||||
|
from .table_tennis.table_tennis_env import TableTennisEnv
|
||||||
|
@ -28,10 +28,10 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
return self.data.qvel[0:7].copy()
|
return self.data.qvel[0:7].copy()
|
||||||
|
|
||||||
# TODO: Fix this
|
# TODO: Fix this
|
||||||
def episode_callback(self, action: np.ndarray, mp) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
|
def episode_callback(self, action: np.ndarray, mp) -> Tuple[np.ndarray, Union[np.ndarray, None], bool]:
|
||||||
if mp.learn_tau:
|
if mp.learn_tau:
|
||||||
self.release_step = action[0] / self.dt # Tau value
|
self.release_step = action[0] / self.dt # Tau value
|
||||||
return action, None
|
return action, None, True
|
||||||
|
|
||||||
def set_context(self, context):
|
def set_context(self, context):
|
||||||
xyz = np.zeros(3)
|
xyz = np.zeros(3)
|
||||||
|
@ -3,6 +3,7 @@ from typing import Union, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||||
|
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound
|
||||||
|
|
||||||
|
|
||||||
class MPWrapper(RawInterfaceWrapper):
|
class MPWrapper(RawInterfaceWrapper):
|
||||||
@ -13,10 +14,8 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
return np.hstack([
|
return np.hstack([
|
||||||
[False] * 7, # joints position
|
[False] * 7, # joints position
|
||||||
[False] * 7, # joints velocity
|
[False] * 7, # joints velocity
|
||||||
[False] * 3, # position of box
|
[False] * 3, # position ball
|
||||||
[False] * 4, # orientation of box
|
[True] * 2, # target landing position
|
||||||
[True] * 3, # position of target
|
|
||||||
[True] * 4, # orientation of target
|
|
||||||
# [True] * 1, # time
|
# [True] * 1, # time
|
||||||
])
|
])
|
||||||
|
|
||||||
@ -27,3 +26,27 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
@property
|
@property
|
||||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
return self.data.qvel[:7].copy()
|
return self.data.qvel[:7].copy()
|
||||||
|
|
||||||
|
def episode_callback(self, action, pos_traj, vel_traj):
    """
    Decide whether the trajectory generated from `action` may be executed.

    Args:
        action: parameter vector; action[0] is compared against tau_bound and
            action[1] against delay_bound
        pos_traj: desired joint position trajectory, compared against the joint limits
        vel_traj: desired joint velocity trajectory (not inspected)

    Returns:
        bool: False if a timing parameter or any desired joint position is out of
        bounds, True otherwise
    """
    tau = action[0]
    if tau > tau_bound[1] or tau < tau_bound[0]:
        return False

    delay = action[1]
    if delay > delay_bound[1] or delay < delay_bound[0]:
        return False

    joints_out_of_range = np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low)
    return not joints_out_of_range
|
||||||
|
|
||||||
|
def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \
        -> Tuple[np.ndarray, float, bool, dict]:
    """
    Build the step result returned in place of a rollout when the trajectory is invalid.

    The reward is the negative sum of the amount by which action[0] leaves tau_bound,
    the amount by which action[1] leaves delay_bound, and the summed joint-limit
    violation over the whole desired position trajectory.

    Args:
        action: parameter vector; only action[0] and action[1] are read
        pos_traj: desired joint position trajectory
        vel_traj: desired joint velocity trajectory (not inspected)

    Returns:
        Tuple: current observation, negative penalty, done=True and an info dict
        describing a failed single-step episode
    """

    def _excess(value, bound):
        # distance by which `value` lies outside [bound[0], bound[1]]; 0 when inside
        return np.max([0, value - bound[1]]) + np.max([0, bound[0] - value])

    penalty = _excess(action[0], tau_bound) + _excess(action[1], delay_bound)
    penalty = penalty + np.sum(np.maximum(pos_traj - jnt_pos_high, 0))
    penalty = penalty + np.sum(np.maximum(jnt_pos_low - pos_traj, 0))

    return self.get_obs(), -penalty, True, {
        "hit_ball": [False],
        "ball_returned_success": [False],
        "land_dist_error": [10.],
        "is_success": [False],
        'trajectory_length': 1,
        "num_steps": [1]
    }
|
@ -4,6 +4,8 @@ import numpy as np
|
|||||||
from gym import utils, spaces
|
from gym import utils, spaces
|
||||||
from gym.envs.mujoco import MujocoEnv
|
from gym.envs.mujoco import MujocoEnv
|
||||||
|
|
||||||
|
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import check_init_state_validity, magnus_force
|
||||||
|
|
||||||
import mujoco
|
import mujoco
|
||||||
|
|
||||||
MAX_EPISODE_STEPS_TABLE_TENNIS = 250
|
MAX_EPISODE_STEPS_TABLE_TENNIS = 250
|
||||||
@ -22,10 +24,23 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
utils.EzPickle.__init__(**locals())
|
utils.EzPickle.__init__(**locals())
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
|
|
||||||
self.hit_ball = False
|
self._hit_ball = False
|
||||||
self.ball_land_on_table = False
|
self._ball_land_on_table = False
|
||||||
|
self._ball_contact_after_hit = False
|
||||||
|
self._ball_return_success = False
|
||||||
|
self._ball_landing_pos = None
|
||||||
|
self._init_ball_state = None
|
||||||
|
self._episode_end = False
|
||||||
|
|
||||||
self._id_set = False
|
self._id_set = False
|
||||||
|
|
||||||
|
# reward calculation
|
||||||
self.ball_landing_pos = None
|
self.ball_landing_pos = None
|
||||||
|
self._goal_pos = np.zeros(2)
|
||||||
|
self._ball_traj = []
|
||||||
|
self._racket_traj = []
|
||||||
|
|
||||||
|
|
||||||
MujocoEnv.__init__(self,
|
MujocoEnv.__init__(self,
|
||||||
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
||||||
frame_skip=frame_skip,
|
frame_skip=frame_skip,
|
||||||
@ -40,11 +55,11 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32)
|
self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32)
|
||||||
|
|
||||||
def _set_ids(self):
|
def _set_ids(self):
|
||||||
self._floor_contact_id = self.model.geom("floor").bodyid[0]
|
self._floor_contact_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_GEOM, "floor")
|
||||||
self._ball_contact_id = self.model.geom("target_ball_contact").bodyid[0]
|
self._ball_contact_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_GEOM, "target_ball_contact")
|
||||||
self._bat_front_id = self.model.geom("bat").bodyid[0]
|
self._bat_front_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_GEOM, "bat")
|
||||||
self._bat_back_id = self.model.geom("bat_back").bodyid[0]
|
self._bat_back_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_GEOM, "bat_back")
|
||||||
self._table_contact_id = self.model.geom("table_tennis_table").bodyid[0]
|
self._table_contact_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_GEOM, "table_tennis_table")
|
||||||
self._id_set = True
|
self._id_set = True
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
@ -53,40 +68,55 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
unstable_simulation = False
|
unstable_simulation = False
|
||||||
|
|
||||||
done = False
|
|
||||||
|
|
||||||
for _ in range(self.frame_skip):
|
for _ in range(self.frame_skip):
|
||||||
try:
|
try:
|
||||||
self.do_simulation(action, self.frame_skip)
|
self.do_simulation(action, 1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Simulation get unstable return with MujocoException: ", e)
|
print("Simulation get unstable return with MujocoException: ", e)
|
||||||
unstable_simulation = True
|
unstable_simulation = True
|
||||||
|
self._episode_end = True
|
||||||
|
break
|
||||||
|
|
||||||
if not self.hit_ball:
|
if not self._hit_ball:
|
||||||
self.hit_ball = self._contact_checker(self._ball_contact_id, self._bat_front_id) or \
|
self._hit_ball = self._contact_checker(self._ball_contact_id, self._bat_front_id) or \
|
||||||
self._contact_checker(self._ball_contact_id, self._bat_back_id)
|
self._contact_checker(self._ball_contact_id, self._bat_back_id)
|
||||||
if not self.hit_ball:
|
if not self._hit_ball:
|
||||||
ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id)
|
ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id)
|
||||||
if ball_land_on_floor_no_hit:
|
if ball_land_on_floor_no_hit:
|
||||||
self.ball_landing_pos = self.data.body("target_ball").xpos.copy()
|
self._ball_landing_pos = self.data.body("target_ball").xpos.copy()
|
||||||
done = True
|
self._episode_end = True
|
||||||
if self.hit_ball and not self.ball_contact_after_hit:
|
if self._hit_ball and not self._ball_contact_after_hit:
|
||||||
if not self.ball_contact_after_hit:
|
if not self._ball_contact_after_hit:
|
||||||
if self._contact_checker(self._ball_contact_id, self._floor_contact_id): # first check contact with floor
|
if self._contact_checker(self._ball_contact_id, self._floor_contact_id): # first check contact with floor
|
||||||
self.ball_contact_after_hit = True
|
self._ball_contact_after_hit = True
|
||||||
self.ball_landing_pos = self.sim.data.geom("target_ball_contact").xpos.copy()
|
self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy()
|
||||||
|
self._episode_end = True
|
||||||
elif self._contact_checker(self._ball_contact_id, self._table_contact_id): # second check contact with table
|
elif self._contact_checker(self._ball_contact_id, self._table_contact_id): # second check contact with table
|
||||||
self.ball_contact_after_hit = True
|
self._ball_contact_after_hit = True
|
||||||
self.ball_landing_pos = self.sim.data.geom("target_ball_contact").xpos.copy()
|
self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy()
|
||||||
if self.ball_landing_pos[0] < 0.: # ball lands on the opponent side
|
if self._ball_landing_pos[0] < 0.: # ball lands on the opponent side
|
||||||
self.ball_return_success = True
|
self._ball_return_success = True
|
||||||
|
self._episode_end = True
|
||||||
|
|
||||||
|
# update ball trajectory & racket trajectory
|
||||||
|
self._ball_traj.append(self.data.body("target_ball").xpos.copy())
|
||||||
|
self._racket_traj.append(self.data.geom("bat").xpos.copy())
|
||||||
|
|
||||||
self._steps += 1
|
self._steps += 1
|
||||||
episode_end = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else False
|
self._episode_end = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else self._episode_end
|
||||||
|
|
||||||
obs = self._get_obs()
|
reward = -25 if unstable_simulation else self._get_reward(self._episode_end)
|
||||||
|
|
||||||
return obs, 0., False, {}
|
land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \
|
||||||
|
if self._ball_landing_pos is not None else 10.
|
||||||
|
|
||||||
|
return self._get_obs(), reward, self._episode_end, {
|
||||||
|
"hit_ball": self._hit_ball,
|
||||||
|
"ball_returned_success": self._ball_return_success,
|
||||||
|
"land_dist_error": land_dist_err,
|
||||||
|
"is_success": self._ball_return_success and land_dist_err < 0.2,
|
||||||
|
"num_steps": self._steps,
|
||||||
|
}
|
||||||
|
|
||||||
def _contact_checker(self, id_1, id_2):
|
def _contact_checker(self, id_1, id_2):
|
||||||
for coni in range(0, self.data.ncon):
|
for coni in range(0, self.data.ncon):
|
||||||
@ -97,31 +127,94 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
new_context = self._sample_context()
|
self._init_ball_state = self._generate_valid_init_ball(random_pos=False, random_vel=False)
|
||||||
self.data.joint("tar_x").qpos = new_context[0]
|
self._goal_pos = self.np_random.uniform(low=self.context_bounds[0][-2:], high=self.context_bounds[1][-2:])
|
||||||
self.data.joint("tar_y").qpos = new_context[1]
|
self.data.joint("tar_x").qpos = self._init_ball_state[0]
|
||||||
self.data.joint("tar_z").qvel = 2.
|
self.data.joint("tar_y").qpos = self._init_ball_state[1]
|
||||||
|
self.data.joint("tar_z").qpos = self._init_ball_state[2]
|
||||||
|
self.data.joint("tar_x").qvel = self._init_ball_state[3]
|
||||||
|
self.data.joint("tar_y").qvel = self._init_ball_state[4]
|
||||||
|
self.data.joint("tar_z").qvel = self._init_ball_state[5]
|
||||||
|
|
||||||
self.ball_landing_pos = None
|
self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]])
|
||||||
self.hit_ball = False
|
|
||||||
|
self.data.qpos[:7] = np.array([0., 0., 0., 1.5, 0., 0., 1.5])
|
||||||
|
|
||||||
|
mujoco.mj_forward(self.model, self.data)
|
||||||
|
|
||||||
|
self._hit_ball = False
|
||||||
|
self._ball_land_on_table = False
|
||||||
|
self._ball_contact_after_hit = False
|
||||||
|
self._ball_return_success = False
|
||||||
|
self._ball_landing_pos = None
|
||||||
|
self._episode_end = False
|
||||||
|
self._ball_traj = []
|
||||||
|
self._racket_traj = []
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def _sample_context(self):
|
|
||||||
return self.np_random.uniform(low=self.context_bounds[0],
|
|
||||||
high=self.context_bounds[1])
|
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
obs = np.concatenate([
|
obs = np.concatenate([
|
||||||
self.data.qpos.flat[:7],
|
self.data.qpos.flat[:7].copy(),
|
||||||
self.data.qvel.flat[:7],
|
self.data.qvel.flat[:7].copy(),
|
||||||
|
self.data.joint("tar_x").qpos.copy(),
|
||||||
|
self.data.joint("tar_y").qpos.copy(),
|
||||||
|
self.data.joint("tar_z").qpos.copy(),
|
||||||
|
# self.data.body("target_ball").xvel.copy(),
|
||||||
|
self._goal_pos.copy(),
|
||||||
])
|
])
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
|
def get_obs(self):
    """Public accessor that returns whatever `_get_obs` currently produces."""
    observation = self._get_obs()
    return observation
|
||||||
|
|
||||||
|
def _get_reward(self, episode_end):
|
||||||
|
if not episode_end:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
min_r_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj) - np.array(self._racket_traj), axis=1))
|
||||||
|
if not self._hit_ball:
|
||||||
|
return 0.2 * (1 - np.tanh(min_r_b_dist**2))
|
||||||
|
else:
|
||||||
|
if self._ball_landing_pos is None:
|
||||||
|
min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:,:2] - self._goal_pos[:2], axis=1))
|
||||||
|
return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist**2))
|
||||||
|
else:
|
||||||
|
min_b_des_b_land_dist = np.linalg.norm(self._goal_pos[:2] - self._ball_landing_pos[:2])
|
||||||
|
over_net_bonus = int(self._ball_landing_pos[0] < 0)
|
||||||
|
return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + 4 * (1 - np.tanh(min_b_des_b_land_dist ** 2)) + over_net_bonus
|
||||||
|
|
||||||
|
def _generate_random_ball(self, random_pos=False, random_vel=False):
|
||||||
|
x_pos, y_pos, z_pos = -0.5, 0.35, 1.75
|
||||||
|
x_vel, y_vel, z_vel = 2.5, 0., 0.5
|
||||||
|
if random_pos:
|
||||||
|
x_pos = self.np_random.uniform(low=self.context_bounds[0][0], high=self.context_bounds[1][0], size=1)
|
||||||
|
y_pos = self.np_random.uniform(low=self.context_bounds[0][1], high=self.context_bounds[1][1], size=1)
|
||||||
|
if random_vel:
|
||||||
|
x_vel = self.np_random.uniform(low=2.0, high=3.0, size=1)
|
||||||
|
init_ball_state = np.array([x_pos, y_pos, z_pos, x_vel, y_vel, z_vel])
|
||||||
|
return init_ball_state
|
||||||
|
|
||||||
|
def _generate_valid_init_ball(self, random_pos=False, random_vel=False):
    """
    Keep generating initial ball states until one passes `check_init_state_validity`.

    Args:
        random_pos: forwarded to `_generate_random_ball`
        random_vel: forwarded to `_generate_random_ball`

    Returns:
        the first generated ball state that is accepted by the validity check
    """
    while True:
        candidate = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
        if check_init_state_validity(candidate):
            return candidate
|
||||||
|
|
||||||
|
def check_traj_validity(self, traj):
    """Placeholder with no implementation; every call raises NotImplementedError."""
    raise NotImplementedError
|
||||||
|
|
||||||
|
def get_invalid_steps(self, traj):
    """
    Return a terminal step result carrying a fixed penalty of -100.

    Args:
        traj: the offending trajectory (not inspected)

    Returns:
        Tuple: current observation, reward -100, done=True and an empty info dict
    """
    return self._get_obs(), -100, True, {}
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
env = TableTennisEnv()
|
env = TableTennisEnv()
|
||||||
env.reset()
|
|
||||||
for _ in range(1000):
|
for _ in range(1000):
|
||||||
for _ in range(200):
|
obs = env.reset()
|
||||||
|
for _ in range(2000):
|
||||||
env.render("human")
|
env.render("human")
|
||||||
env.step(env.action_space.sample())
|
obs, reward, done, info = env.step(np.zeros(7))
|
||||||
|
print(reward)
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
@ -1 +1,51 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Per-joint position limits (7 joints), compared element-wise against desired
# joint trajectories. Presumably radians - TODO confirm against the robot model.
jnt_pos_low = np.array([-2.6, -2.0, -2.8, -0.9, -4.8, -1.6, -2.2])
jnt_pos_high = np.array([2.6, 2.0, 2.8, 3.1, 1.3, 1.6, 2.2])
# [lower, upper] bounds for the learned delay and tau parameters.
delay_bound = [0.05, 0.3]
tau_bound = [0.5, 1.5]

# Scene geometry and gravity used by the ballistic checks below.
# Presumably SI units (metres, m/s^2) - TODO confirm against the MuJoCo XML.
net_height = 0.1
table_height = 0.77
table_x_min = -1.1
table_x_max = 1.1
table_y_min = -0.6
table_y_max = 0.6
g = 9.81


def check_init_state_validity(init_state):
    """
    Check with a drag-free ballistic model whether an initial ball state is usable.

    The state is rejected when the ball starts at x > -0.2, moves in negative x
    direction, has not yet reached x >= 0.05 when it falls to net height, or lands
    beyond `table_x_max` or outside [`table_y_min`, `table_y_max`]. It is also
    rejected when the ballistic arc never reaches the net or landing height at all
    (negative square-root radicand); previously that case produced NaN times, made
    every later comparison False and was therefore reported as valid.

    Args:
        init_state: sequence of 6 values [x, y, z, v_x, v_y, v_z]

    Returns:
        bool: True if the state passes all checks, False otherwise
    """
    assert len(init_state) == 6, "init_state must be a 6D vector (pos+vel),got {}".format(init_state)
    x = init_state[0]
    y = init_state[1]
    # NOTE(review): height is taken relative to the table with an extra 0.1 offset;
    # the reason for the offset is not visible in this file.
    z = init_state[2] - table_height + 0.1
    v_x = init_state[3]
    v_y = init_state[4]
    v_z = init_state[5]

    # check if the initial state is wrong
    if x > -0.2:
        return False
    # check if the ball velocity direction is wrong
    if v_x < 0.:
        return False
    # check if the ball can pass the net
    net_radicand = 4 * (v_z ** 2) / g ** 2 - 8 * (net_height - z) / g
    if net_radicand < 0:
        # the arc never reaches net height, so the ball cannot clear the net
        return False
    t_n = (-2. * (-v_z) / g + np.sqrt(net_radicand)) / 2.
    if x + v_x * t_n < 0.05:
        return False
    # check if ball landing position will violate x bounds
    land_radicand = 4 * (v_z ** 2) / g ** 2 + 8 * (z) / g
    if land_radicand < 0:
        # the ball starts below the landing plane and never rises to it
        return False
    t_l = (-2. * (-v_z) / g + np.sqrt(land_radicand)) / 2.
    if x + v_x * t_l > table_x_max:
        return False
    # check if ball landing position will violate y bounds
    if y + v_y * t_l > table_y_max or y + v_y * t_l < table_y_min:
        return False

    return True
|
||||||
|
|
||||||
|
def magnus_force(top_spin=0.0, side_spin=0.0, v_ball=np.zeros(3), v_wind=np.zeros(3)):
    """
    Compute the Magnus (spin-induced lift) force acting on the ball.

    The force is 0.5 * rho * A * C_l * |v_rel| * (w x v_rel), with
    v_rel = v_ball - v_wind and w = [0, top_spin, side_spin].

    Args:
        top_spin: second component of the ball's angular velocity
        side_spin: third component of the ball's angular velocity
        v_ball: ball velocity, 3 values (the default array is only read, never modified)
        v_wind: wind velocity, 3 values (the default array is only read, never modified)

    Returns:
        np.ndarray: force vector of shape (3,)
    """
    v_ball = np.asarray(v_ball, dtype=float)
    v_rel = v_ball - np.asarray(v_wind, dtype=float)

    rho = 1.225  # Air density
    # The constants below were written as `10e-3`, `10e-4` and `10e-5`, which are
    # 1e-2, 1e-3 and 1e-4, i.e. ten times the intended powers of ten.
    # 1.256e-3 equals pi * 0.02**2, which presumably corresponds to a 40 mm ball.
    A = 1.256e-3  # Cross-section area of ball
    C_l = 4.68e-4 - 2.0984e-5 * (np.linalg.norm(v_ball) - 50)  # Lift force coefficient or simply 1.23
    w = np.array([0.0, top_spin, side_spin])  # Angular velocity of ball
    f_m = 0.5 * rho * A * C_l * np.linalg.norm(v_rel) * np.cross(w, v_rel)
    return f_m
|
||||||
|
@ -157,17 +157,18 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
render = True
|
render = True
|
||||||
# DMP
|
# DMP
|
||||||
example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
|
# example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
|
||||||
|
|
||||||
# ProMP
|
# ProMP
|
||||||
example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
|
# example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
|
||||||
example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
|
# example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
|
||||||
|
example_mp("TableTennis4DProMP-v0", seed=10, iterations=5, render=render)
|
||||||
|
|
||||||
# ProDMP
|
# ProDMP
|
||||||
example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=16, render=render)
|
# example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=16, render=render)
|
||||||
|
|
||||||
# Altered basis functions
|
# Altered basis functions
|
||||||
obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
|
# obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
|
||||||
|
|
||||||
# Custom MP
|
# Custom MP
|
||||||
example_fully_custom_mp(seed=10, iterations=1, render=render)
|
# example_fully_custom_mp(seed=10, iterations=1, render=render)
|
||||||
|
Loading…
Reference in New Issue
Block a user