This commit is contained in:
Hongyi Zhou 2022-11-04 21:22:32 +01:00
parent 7b2451d317
commit 5a547d85f9
9 changed files with 336 additions and 103 deletions

View File

@ -23,7 +23,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
replanning_schedule: Optional[ replanning_schedule: Optional[
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None, Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
reward_aggregation: Callable[[np.ndarray], float] = np.sum, reward_aggregation: Callable[[np.ndarray], float] = np.sum,
max_planning_times: int = 1, max_planning_times = None,
desired_conditioning: bool = False desired_conditioning: bool = False
): ):
""" """
@ -163,8 +163,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
# action = np.concatenate((basis_weights, goal_weights), axis=1).flatten() # action = np.concatenate((basis_weights, goal_weights), axis=1).flatten()
# TODO remove this part, right now only needed for beer pong # TODO remove this part, right now only needed for beer pong
mp_params, env_spec_params = self.env.episode_callback(action, self.traj_gen) # mp_params, env_spec_params, proceed = self.env.episode_callback(action, self.traj_gen)
position, velocity = self.get_trajectory(mp_params) position, velocity = self.get_trajectory(action)
traj_is_valid = self.env.episode_callback(action, position, velocity)
trajectory_length = len(position) trajectory_length = len(position)
rewards = np.zeros(shape=(trajectory_length,)) rewards = np.zeros(shape=(trajectory_length,))
@ -176,6 +177,13 @@ class BlackBoxWrapper(gym.ObservationWrapper):
infos = dict() infos = dict()
done = False done = False
if self.verbose >= 2:
desired_pos_traj = []
desired_vel_traj = []
pos_traj = []
vel_traj = []
if traj_is_valid:
self.plan_counts += 1 self.plan_counts += 1
for t, (pos, vel) in enumerate(zip(position, velocity)): for t, (pos, vel) in enumerate(zip(position, velocity)):
step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel) step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
@ -192,14 +200,20 @@ class BlackBoxWrapper(gym.ObservationWrapper):
elems[t] = v elems[t] = v
infos[k] = elems infos[k] = elems
if self.verbose >= 2:
desired_pos_traj.append(pos)
desired_vel_traj.append(vel)
pos_traj.append(self.current_pos)
vel_traj.append(self.current_vel)
if self.render_kwargs: if self.render_kwargs:
self.env.render(**self.render_kwargs) self.env.render(**self.render_kwargs)
if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action, if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
t + 1 + self.current_traj_steps): t + 1 + self.current_traj_steps):
if self.max_planning_times is not None and self.plan_counts >= self.max_planning_times: # if self.max_planning_times is not None and self.plan_counts >= self.max_planning_times:
continue # continue
self.condition_pos = pos if self.desired_conditioning else self.current_pos self.condition_pos = pos if self.desired_conditioning else self.current_pos
self.condition_vel = vel if self.desired_conditioning else self.current_vel self.condition_vel = vel if self.desired_conditioning else self.current_vel
@ -217,11 +231,17 @@ class BlackBoxWrapper(gym.ObservationWrapper):
infos['step_actions'] = actions[:t + 1] infos['step_actions'] = actions[:t + 1]
infos['step_observations'] = observations[:t + 1] infos['step_observations'] = observations[:t + 1]
infos['step_rewards'] = rewards[:t + 1] infos['step_rewards'] = rewards[:t + 1]
infos['desired_pos_traj'] = np.array(desired_pos_traj)
infos['desired_vel_traj'] = np.array(desired_vel_traj)
infos['pos_traj'] = np.array(pos_traj)
infos['vel_traj'] = np.array(vel_traj)
infos['trajectory_length'] = t + 1 infos['trajectory_length'] = t + 1
trajectory_return = self.reward_aggregation(rewards[:t + 1]) trajectory_return = self.reward_aggregation(rewards[:t + 1])
return self.observation(obs), trajectory_return, done, infos return self.observation(obs), trajectory_return, done, infos
else:
obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity)
return self.observation(obs), trajectory_return, done, infos
def render(self, **kwargs): def render(self, **kwargs):
"""Only set render options here, such that they can be used during the rollout. """Only set render options here, such that they can be used during the rollout.
This only needs to be called once""" This only needs to be called once"""

View File

@ -52,8 +52,7 @@ class RawInterfaceWrapper(gym.Wrapper):
""" """
return self.env.dt return self.env.dt
def episode_callback(self, action: np.ndarray, traj_gen: MPInterface) -> Tuple[ def episode_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.array) -> Tuple[bool]:
np.ndarray, Union[np.ndarray, None]]:
""" """
Used to extract the parameters for the movement primitive and other parameters from an action array which might Used to extract the parameters for the movement primitive and other parameters from an action array which might
include other actions like ball releasing time for the beer pong environment. include other actions like ball releasing time for the beer pong environment.
@ -65,4 +64,11 @@ class RawInterfaceWrapper(gym.Wrapper):
Returns: Returns:
Tuple: mp_arguments and other arguments Tuple: mp_arguments and other arguments
""" """
return action, None return True
def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
"""
Used to return a fake return from the environment if the desired trajectory is invalid.
"""
obs = np.zeros(1)
return obs, 0, True, {}

View File

@ -28,7 +28,9 @@ DEFAULT_BB_DICT_ProMP = {
'trajectory_generator_type': 'promp' 'trajectory_generator_type': 'promp'
}, },
"phase_generator_kwargs": { "phase_generator_kwargs": {
'phase_generator_type': 'linear' 'phase_generator_type': 'linear',
'learn_tau': False,
'learn_delay': False,
}, },
"controller_kwargs": { "controller_kwargs": {
'controller_type': 'motor', 'controller_type': 'motor',
@ -40,6 +42,8 @@ DEFAULT_BB_DICT_ProMP = {
'num_basis': 5, 'num_basis': 5,
'num_basis_zero_start': 1, 'num_basis_zero_start': 1,
'basis_bandwidth_factor': 3.0, 'basis_bandwidth_factor': 3.0,
},
"black_box_kwargs": {
} }
} }
@ -245,6 +249,18 @@ register(
max_episode_steps=FIXED_RELEASE_STEP, max_episode_steps=FIXED_RELEASE_STEP,
) )
# Table Tennis environments
for ctxt_dim in [2, 4]:
register(
id='TableTennis{}D-v0'.format(ctxt_dim),
entry_point='fancy_gym.envs.mujoco:TableTennisEnv',
max_episode_steps=350,
kwargs={
"ctxt_dim": ctxt_dim,
'frame_skip': 4
}
)
# movement Primitive Environments # movement Primitive Environments
## Simple Reacher ## Simple Reacher
@ -515,6 +531,29 @@ for _v in _versions:
kwargs=kwargs_dict_box_pushing_prodmp kwargs=kwargs_dict_box_pushing_prodmp
) )
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id) ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
## Table Tennis
_versions = ['TableTennis2D-v0', 'TableTennis4D-v0']
for _v in _versions:
_name = _v.split("-")
_env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_tt_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_tt_promp['wrappers'].append(mujoco.table_tennis.MPWrapper)
kwargs_dict_tt_promp['name'] = _v
kwargs_dict_tt_promp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0])
kwargs_dict_tt_promp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
kwargs_dict_tt_promp['phase_generator_kwargs']['learn_tau'] = True
kwargs_dict_tt_promp['phase_generator_kwargs']['learn_delay'] = True
kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis'] = 3
kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2
kwargs_dict_tt_promp['black_box_kwargs']['duration'] = 2.
kwargs_dict_tt_promp['black_box_kwargs']['verbose'] = 2
register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_tt_promp
)
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
# #
# ## Walker2DJump # ## Walker2DJump
# _versions = ['Walker2DJump-v0'] # _versions = ['Walker2DJump-v0']

View File

@ -8,3 +8,4 @@ from .hopper_throw.hopper_throw_in_basket import HopperThrowInBasketEnv
from .reacher.reacher import ReacherEnv from .reacher.reacher import ReacherEnv
from .walker_2d_jump.walker_2d_jump import Walker2dJumpEnv from .walker_2d_jump.walker_2d_jump import Walker2dJumpEnv
from .box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, BoxPushingTemporalSpatialSparse from .box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, BoxPushingTemporalSpatialSparse
from .table_tennis.table_tennis_env import TableTennisEnv

View File

@ -28,10 +28,10 @@ class MPWrapper(RawInterfaceWrapper):
return self.data.qvel[0:7].copy() return self.data.qvel[0:7].copy()
# TODO: Fix this # TODO: Fix this
def episode_callback(self, action: np.ndarray, mp) -> Tuple[np.ndarray, Union[np.ndarray, None]]: def episode_callback(self, action: np.ndarray, mp) -> Tuple[np.ndarray, Union[np.ndarray, None], bool]:
if mp.learn_tau: if mp.learn_tau:
self.release_step = action[0] / self.dt # Tau value self.release_step = action[0] / self.dt # Tau value
return action, None return action, None, True
def set_context(self, context): def set_context(self, context):
xyz = np.zeros(3) xyz = np.zeros(3)

View File

@ -3,6 +3,7 @@ from typing import Union, Tuple
import numpy as np import numpy as np
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound
class MPWrapper(RawInterfaceWrapper): class MPWrapper(RawInterfaceWrapper):
@ -13,10 +14,8 @@ class MPWrapper(RawInterfaceWrapper):
return np.hstack([ return np.hstack([
[False] * 7, # joints position [False] * 7, # joints position
[False] * 7, # joints velocity [False] * 7, # joints velocity
[False] * 3, # position of box [False] * 3, # position ball
[False] * 4, # orientation of box [True] * 2, # target landing position
[True] * 3, # position of target
[True] * 4, # orientation of target
# [True] * 1, # time # [True] * 1, # time
]) ])
@ -27,3 +26,27 @@ class MPWrapper(RawInterfaceWrapper):
@property @property
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
return self.data.qvel[:7].copy() return self.data.qvel[:7].copy()
def episode_callback(self, action, pos_traj, vel_traj):
time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
or action[1] > delay_bound[1] or action[1] < delay_bound[0]
if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
return False
return True
def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \
-> Tuple[np.ndarray, float, bool, dict]:
tau_invalid_penalty = np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]])
delay_invalid_penalty = np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]])
violate_high_bound_error = np.sum(np.maximum(pos_traj - jnt_pos_high, 0))
violate_low_bound_error = np.sum(np.maximum(jnt_pos_low - pos_traj, 0))
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
violate_high_bound_error + violate_low_bound_error
return self.get_obs(), -invalid_penalty, True, {
"hit_ball": [False],
"ball_returned_success": [False],
"land_dist_error": [10.],
"is_success": [False],
'trajectory_length': 1,
"num_steps": [1]
}

View File

@ -4,6 +4,8 @@ import numpy as np
from gym import utils, spaces from gym import utils, spaces
from gym.envs.mujoco import MujocoEnv from gym.envs.mujoco import MujocoEnv
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import check_init_state_validity, magnus_force
import mujoco import mujoco
MAX_EPISODE_STEPS_TABLE_TENNIS = 250 MAX_EPISODE_STEPS_TABLE_TENNIS = 250
@ -22,10 +24,23 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
utils.EzPickle.__init__(**locals()) utils.EzPickle.__init__(**locals())
self._steps = 0 self._steps = 0
self.hit_ball = False self._hit_ball = False
self.ball_land_on_table = False self._ball_land_on_table = False
self._ball_contact_after_hit = False
self._ball_return_success = False
self._ball_landing_pos = None
self._init_ball_state = None
self._episode_end = False
self._id_set = False self._id_set = False
# reward calculation
self.ball_landing_pos = None self.ball_landing_pos = None
self._goal_pos = np.zeros(2)
self._ball_traj = []
self._racket_traj = []
MujocoEnv.__init__(self, MujocoEnv.__init__(self,
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"), model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
frame_skip=frame_skip, frame_skip=frame_skip,
@ -40,11 +55,11 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32) self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32)
def _set_ids(self): def _set_ids(self):
self._floor_contact_id = self.model.geom("floor").bodyid[0] self._floor_contact_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_GEOM, "floor")
self._ball_contact_id = self.model.geom("target_ball_contact").bodyid[0] self._ball_contact_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_GEOM, "target_ball_contact")
self._bat_front_id = self.model.geom("bat").bodyid[0] self._bat_front_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_GEOM, "bat")
self._bat_back_id = self.model.geom("bat_back").bodyid[0] self._bat_back_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_GEOM, "bat_back")
self._table_contact_id = self.model.geom("table_tennis_table").bodyid[0] self._table_contact_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_GEOM, "table_tennis_table")
self._id_set = True self._id_set = True
def step(self, action): def step(self, action):
@ -53,40 +68,55 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
unstable_simulation = False unstable_simulation = False
done = False
for _ in range(self.frame_skip): for _ in range(self.frame_skip):
try: try:
self.do_simulation(action, self.frame_skip) self.do_simulation(action, 1)
except Exception as e: except Exception as e:
print("Simulation get unstable return with MujocoException: ", e) print("Simulation get unstable return with MujocoException: ", e)
unstable_simulation = True unstable_simulation = True
self._episode_end = True
break
if not self.hit_ball: if not self._hit_ball:
self.hit_ball = self._contact_checker(self._ball_contact_id, self._bat_front_id) or \ self._hit_ball = self._contact_checker(self._ball_contact_id, self._bat_front_id) or \
self._contact_checker(self._ball_contact_id, self._bat_back_id) self._contact_checker(self._ball_contact_id, self._bat_back_id)
if not self.hit_ball: if not self._hit_ball:
ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id) ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id)
if ball_land_on_floor_no_hit: if ball_land_on_floor_no_hit:
self.ball_landing_pos = self.data.body("target_ball").xpos.copy() self._ball_landing_pos = self.data.body("target_ball").xpos.copy()
done = True self._episode_end = True
if self.hit_ball and not self.ball_contact_after_hit: if self._hit_ball and not self._ball_contact_after_hit:
if not self.ball_contact_after_hit: if not self._ball_contact_after_hit:
if self._contact_checker(self._ball_contact_id, self._floor_contact_id): # first check contact with floor if self._contact_checker(self._ball_contact_id, self._floor_contact_id): # first check contact with floor
self.ball_contact_after_hit = True self._ball_contact_after_hit = True
self.ball_landing_pos = self.sim.data.geom("target_ball_contact").xpos.copy() self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy()
self._episode_end = True
elif self._contact_checker(self._ball_contact_id, self._table_contact_id): # second check contact with table elif self._contact_checker(self._ball_contact_id, self._table_contact_id): # second check contact with table
self.ball_contact_after_hit = True self._ball_contact_after_hit = True
self.ball_landing_pos = self.sim.data.geom("target_ball_contact").xpos.copy() self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy()
if self.ball_landing_pos[0] < 0.: # ball lands on the opponent side if self._ball_landing_pos[0] < 0.: # ball lands on the opponent side
self.ball_return_success = True self._ball_return_success = True
self._episode_end = True
# update ball trajectory & racket trajectory
self._ball_traj.append(self.data.body("target_ball").xpos.copy())
self._racket_traj.append(self.data.geom("bat").xpos.copy())
self._steps += 1 self._steps += 1
episode_end = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else False self._episode_end = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else self._episode_end
obs = self._get_obs() reward = -25 if unstable_simulation else self._get_reward(self._episode_end)
return obs, 0., False, {} land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \
if self._ball_landing_pos is not None else 10.
return self._get_obs(), reward, self._episode_end, {
"hit_ball": self._hit_ball,
"ball_returned_success": self._ball_return_success,
"land_dist_error": land_dist_err,
"is_success": self._ball_return_success and land_dist_err < 0.2,
"num_steps": self._steps,
}
def _contact_checker(self, id_1, id_2): def _contact_checker(self, id_1, id_2):
for coni in range(0, self.data.ncon): for coni in range(0, self.data.ncon):
@ -97,31 +127,94 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
def reset_model(self): def reset_model(self):
self._steps = 0 self._steps = 0
new_context = self._sample_context() self._init_ball_state = self._generate_valid_init_ball(random_pos=False, random_vel=False)
self.data.joint("tar_x").qpos = new_context[0] self._goal_pos = self.np_random.uniform(low=self.context_bounds[0][-2:], high=self.context_bounds[1][-2:])
self.data.joint("tar_y").qpos = new_context[1] self.data.joint("tar_x").qpos = self._init_ball_state[0]
self.data.joint("tar_z").qvel = 2. self.data.joint("tar_y").qpos = self._init_ball_state[1]
self.data.joint("tar_z").qpos = self._init_ball_state[2]
self.data.joint("tar_x").qvel = self._init_ball_state[3]
self.data.joint("tar_y").qvel = self._init_ball_state[4]
self.data.joint("tar_z").qvel = self._init_ball_state[5]
self.ball_landing_pos = None self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]])
self.hit_ball = False
self.data.qpos[:7] = np.array([0., 0., 0., 1.5, 0., 0., 1.5])
mujoco.mj_forward(self.model, self.data)
self._hit_ball = False
self._ball_land_on_table = False
self._ball_contact_after_hit = False
self._ball_return_success = False
self._ball_landing_pos = None
self._episode_end = False
self._ball_traj = []
self._racket_traj = []
return self._get_obs() return self._get_obs()
def _sample_context(self):
return self.np_random.uniform(low=self.context_bounds[0],
high=self.context_bounds[1])
def _get_obs(self): def _get_obs(self):
obs = np.concatenate([ obs = np.concatenate([
self.data.qpos.flat[:7], self.data.qpos.flat[:7].copy(),
self.data.qvel.flat[:7], self.data.qvel.flat[:7].copy(),
self.data.joint("tar_x").qpos.copy(),
self.data.joint("tar_y").qpos.copy(),
self.data.joint("tar_z").qpos.copy(),
# self.data.body("target_ball").xvel.copy(),
self._goal_pos.copy(),
]) ])
return obs return obs
def get_obs(self):
return self._get_obs()
def _get_reward(self, episode_end):
if not episode_end:
return 0
else:
min_r_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj) - np.array(self._racket_traj), axis=1))
if not self._hit_ball:
return 0.2 * (1 - np.tanh(min_r_b_dist**2))
else:
if self._ball_landing_pos is None:
min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:,:2] - self._goal_pos[:2], axis=1))
return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist**2))
else:
min_b_des_b_land_dist = np.linalg.norm(self._goal_pos[:2] - self._ball_landing_pos[:2])
over_net_bonus = int(self._ball_landing_pos[0] < 0)
return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + 4 * (1 - np.tanh(min_b_des_b_land_dist ** 2)) + over_net_bonus
def _generate_random_ball(self, random_pos=False, random_vel=False):
x_pos, y_pos, z_pos = -0.5, 0.35, 1.75
x_vel, y_vel, z_vel = 2.5, 0., 0.5
if random_pos:
x_pos = self.np_random.uniform(low=self.context_bounds[0][0], high=self.context_bounds[1][0], size=1)
y_pos = self.np_random.uniform(low=self.context_bounds[0][1], high=self.context_bounds[1][1], size=1)
if random_vel:
x_vel = self.np_random.uniform(low=2.0, high=3.0, size=1)
init_ball_state = np.array([x_pos, y_pos, z_pos, x_vel, y_vel, z_vel])
return init_ball_state
def _generate_valid_init_ball(self, random_pos=False, random_vel=False):
init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
while not check_init_state_validity(init_ball_state):
init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
return init_ball_state
def check_traj_validity(self, traj):
raise NotImplementedError
def get_invalid_steps(self, traj):
penalty = -100
return self._get_obs(), penalty, True, {}
if __name__ == "__main__": if __name__ == "__main__":
env = TableTennisEnv() env = TableTennisEnv()
env.reset()
for _ in range(1000): for _ in range(1000):
for _ in range(200): obs = env.reset()
for _ in range(2000):
env.render("human") env.render("human")
env.step(env.action_space.sample()) obs, reward, done, info = env.step(np.zeros(7))
print(reward)
if done:
break

View File

@ -1 +1,51 @@
import numpy as np
jnt_pos_low = np.array([-2.6, -2.0, -2.8, -0.9, -4.8, -1.6, -2.2])
jnt_pos_high = np.array([2.6, 2.0, 2.8, 3.1, 1.3, 1.6, 2.2])
delay_bound = [0.05, 0.3]
tau_bound = [0.5, 1.5]
net_height = 0.1
table_height = 0.77
table_x_min = -1.1
table_x_max = 1.1
table_y_min = -0.6
table_y_max = 0.6
g = 9.81
def check_init_state_validity(init_state):
assert len(init_state) == 6, "init_state must be a 6D vector (pos+vel),got {}".format(init_state)
x = init_state[0]
y = init_state[1]
z = init_state[2] - table_height + 0.1
v_x = init_state[3]
v_y = init_state[4]
v_z = init_state[5]
# check if the initial state is wrong
if x > -0.2:
return False
# check if the ball velocity direction is wrong
if v_x < 0.:
return False
# check if the ball can pass the net
t_n = (-2.*(-v_z)/g + np.sqrt(4*(v_z**2)/g**2 - 8*(net_height-z)/g))/2.
if x + v_x * t_n < 0.05:
return False
# check if ball landing position will violate x bounds
t_l = (-2.*(-v_z)/g + np.sqrt(4*(v_z**2)/g**2 + 8*(z)/g))/2.
if x + v_x * t_l > table_x_max:
return False
# check if ball landing position will violate y bounds
if y + v_y * t_l > table_y_max or y + v_y * t_l < table_y_min:
return False
return True
def magnus_force(top_spin=0.0, side_spin=0.0, v_ball=np.zeros(3), v_wind=np.zeros(3)):
rho = 1.225 # Air density
A = 1.256 * 10e-3 # Cross-section area of ball
C_l = 4.68 * 10e-4 - 2.0984 * 10e-5 * (np.linalg.norm(v_ball) - 50) # Lift force coeffient or simply 1.23
w = np.array([0.0, top_spin, side_spin]) # Angular velocity of ball
f_m = 0.5 * rho * A * C_l * np.linalg.norm(v_ball-v_wind) * np.cross(w, v_ball-v_wind)
return f_m

View File

@ -157,17 +157,18 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
if __name__ == '__main__': if __name__ == '__main__':
render = True render = True
# DMP # DMP
example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render) # example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
# ProMP # ProMP
example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render) # example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
example_mp("TableTennis4DProMP-v0", seed=10, iterations=5, render=render)
# ProDMP # ProDMP
example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=16, render=render) # example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=16, render=render)
# Altered basis functions # Altered basis functions
obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render) # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
# Custom MP # Custom MP
example_fully_custom_mp(seed=10, iterations=1, render=render) # example_fully_custom_mp(seed=10, iterations=1, render=render)