add invalid trajectory callback & invalid traj return & register all 3 variants of table tennis tasks

Hongyi Zhou 2022-12-01 11:28:03 +01:00
parent 28aa430fd2
commit f376772c22
9 changed files with 151 additions and 162 deletions

View File

@@ -63,14 +63,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         self.traj_gen_action_space = self._get_traj_gen_action_space()
         self.action_space = self._get_action_space()
-        # no goal learning
-        # tricky_action_upperbound = [np.inf] * (self.traj_gen_action_space.shape[0] - 7)
-        # tricky_action_lowerbound = [-np.inf] * (self.traj_gen_action_space.shape[0] - 7)
-        # self.action_space = spaces.Box(np.array(tricky_action_lowerbound), np.array(tricky_action_upperbound), dtype=np.float32)
-        self.action_space.low[0] = 0.8
-        self.action_space.high[0] = 1.5
-        self.action_space.low[1] = 0.05
-        self.action_space.high[1] = 0.15
         self.observation_space = self._get_observation_space()

         # rendering
@@ -93,8 +85,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         return observation.astype(self.observation_space.dtype)

     def get_trajectory(self, action: np.ndarray) -> Tuple:
-        # duration = self.duration
-        duration = self.duration - self.current_traj_steps * self.dt
+        duration = self.duration
+        # duration = self.duration - self.current_traj_steps * self.dt
         if self.learn_sub_trajectories:
             duration = None
             # reset with every new call as we need to set all arguments, such as tau, delay, again.
@@ -157,8 +149,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         # TODO remove this part, right now only needed for beer pong
         # mp_params, env_spec_params, proceed = self.env.episode_callback(action, self.traj_gen)
         position, velocity = self.get_trajectory(action)
-        traj_is_valid = self.env.episode_callback(action, position, velocity)
-        # insert validation here
+        traj_is_valid = self.env.preprocessing_and_validity_callback(action, position, velocity)

         trajectory_length = len(position)
         rewards = np.zeros(shape=(trajectory_length,))
         if self.verbose >= 2:
@@ -169,7 +161,11 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         infos = dict()
         done = False

-        if traj_is_valid:
+        if not traj_is_valid:
+            obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity,
+                                                                                 self.return_context_observation)
+            return self.observation(obs), trajectory_return, done, infos
+        else:
             self.plan_steps += 1
             for t, (pos, vel) in enumerate(zip(position, velocity)):
                 current_pos = self.current_pos
@@ -215,10 +211,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
             infos['trajectory_length'] = t + 1
             trajectory_return = self.reward_aggregation(rewards[:t + 1])
             return self.observation(obs), trajectory_return, done, infos
-        else:
-            obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity,
-                                                                                 self.return_context_observation)
-            return self.observation(obs), trajectory_return, done, infos

     def render(self, **kwargs):
         """Only set render options here, such that they can be used during the rollout.

View File

@@ -52,6 +52,19 @@ class RawInterfaceWrapper(gym.Wrapper):
         """
         return self.env.dt

+    def preprocessing_and_validity_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) \
+            -> Tuple[bool, np.ndarray, np.ndarray]:
+        """
+        Used to preprocess the action and check if the desired trajectory is valid.
+        """
+        return True, pos_traj, vel_traj
+
+    def set_episode_arguments(self, action, pos_traj, vel_traj):
+        """
+        Used to set the arguments for the env that are valid for the whole episode.
+        """
+        return pos_traj, vel_traj
+
     def episode_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.array) -> Tuple[bool]:
         """
         Used to extract the parameters for the movement primitive and other parameters from an action array which might
@@ -68,7 +81,6 @@ class RawInterfaceWrapper(gym.Wrapper):
     def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
         """
-        Used to return a fake return from the environment if the desired trajectory is invalid.
+        Used to return an artificial return from the env if the desired trajectory is invalid.
         """
-        obs = np.zeros(1)
-        return obs, 0, True, {}
+        return np.zeros(1), 0, True, {}
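Task envs customize these hooks by overriding them. A minimal sketch of a hypothetical subclass follows; the wrapper name and joint limit are invented, and the fourth `return_contextual_obs` argument follows the call site in BlackBoxWrapper above rather than the three-argument base signature.

import numpy as np
from typing import Tuple
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper

class MyRawWrapper(RawInterfaceWrapper):
    # (the usual abstract members context_mask, current_pos, current_vel omitted for brevity)
    _jnt_limit = 2.0  # hypothetical symmetric joint position limit

    def preprocessing_and_validity_callback(self, action: np.ndarray, pos_traj: np.ndarray,
                                            vel_traj: np.ndarray) -> Tuple[bool, np.ndarray, np.ndarray]:
        # flag trajectories whose desired positions leave the joint range
        valid = not np.any(np.abs(pos_traj) > self._jnt_limit)
        return valid, pos_traj, vel_traj

    def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray,
                              return_contextual_obs: bool) -> Tuple[np.ndarray, float, bool, dict]:
        # artificial transition: penalty proportional to the mean bound violation
        penalty = float(np.mean(np.maximum(np.abs(pos_traj) - self._jnt_limit, 0)))
        return np.zeros(1), -penalty, True, {'trajectory_length': 1}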

View File

@@ -18,6 +18,8 @@ from .mujoco.reacher.reacher import ReacherEnv, MAX_EPISODE_STEPS_REACHER
 from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP
 from .mujoco.box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, \
     BoxPushingTemporalSpatialSparse, MAX_EPISODE_STEPS_BOX_PUSHING
+from .mujoco.table_tennis.table_tennis_env import TableTennisEnv, TableTennisWind, TableTennisGoalSwitching, \
+    MAX_EPISODE_STEPS_TABLE_TENNIS

 ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
@@ -248,17 +250,28 @@ for ctxt_dim in [2, 4]:
     register(
         id='TableTennis{}D-v0'.format(ctxt_dim),
         entry_point='fancy_gym.envs.mujoco:TableTennisEnv',
-        max_episode_steps=350,
+        max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS,
         kwargs={
             "ctxt_dim": ctxt_dim,
             'frame_skip': 4,
-            'enable_wind': False,
-            'enable_switching_goal': False,
-            'enable_air': False,
-            'enable_artifical_wind': False,
+            'goal_switching_step': None,
+            'enable_artificial_wind': False,
         }
     )

+register(
+    id='TableTennisWind-v0',
+    entry_point='fancy_gym.envs.mujoco:TableTennisWind',
+    max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS,
+)
+
+register(
+    id='TableTennisGoalSwitching-v0',
+    entry_point='fancy_gym.envs.mujoco:TableTennisGoalSwitching',
+    max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS,
+)
+
 # movement Primitive Environments

 ## Simple Reacher
@@ -529,17 +542,22 @@ for _v in _versions:
     ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)

 ## Table Tennis
-_versions = ['TableTennis2D-v0', 'TableTennis4D-v0']
+_versions = ['TableTennis2D-v0', 'TableTennis4D-v0', 'TableTennisWind-v0', 'TableTennisGoalSwitching-v0']
 for _v in _versions:
     _name = _v.split("-")
     _env_id = f'{_name[0]}ProMP-{_name[1]}'
     kwargs_dict_tt_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
-    kwargs_dict_tt_promp['wrappers'].append(mujoco.table_tennis.MPWrapper)
+    if _v == 'TableTennisWind-v0':
+        kwargs_dict_tt_promp['wrappers'].append(mujoco.table_tennis.TTVelObs_MPWrapper)
+    else:
+        kwargs_dict_tt_promp['wrappers'].append(mujoco.table_tennis.TT_MPWrapper)
     kwargs_dict_tt_promp['name'] = _v
     kwargs_dict_tt_promp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0])
     kwargs_dict_tt_promp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
     kwargs_dict_tt_promp['phase_generator_kwargs']['learn_tau'] = True
     kwargs_dict_tt_promp['phase_generator_kwargs']['learn_delay'] = True
+    kwargs_dict_tt_promp['phase_generator_kwargs']['tau_bound'] = [0.8, 1.5]
+    kwargs_dict_tt_promp['phase_generator_kwargs']['delay_bound'] = [0.05, 0.15]
     kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis'] = 3
     kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2
     kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_goal'] = 1
@@ -556,7 +574,10 @@ for _v in _versions:
     _name = _v.split("-")
     _env_id = f'{_name[0]}ProDMP-{_name[1]}'
     kwargs_dict_tt_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
-    kwargs_dict_tt_prodmp['wrappers'].append(mujoco.table_tennis.MPWrapper)
+    if _v == 'TableTennisWind-v0':
+        kwargs_dict_tt_prodmp['wrappers'].append(mujoco.table_tennis.TTVelObs_MPWrapper)
+    else:
+        kwargs_dict_tt_prodmp['wrappers'].append(mujoco.table_tennis.TT_MPWrapper)
     kwargs_dict_tt_prodmp['name'] = _v
     kwargs_dict_tt_prodmp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0])
     kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
@@ -564,12 +585,13 @@ for _v in _versions:
     kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_scale'] = 1.0
     kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['auto_scale_basis'] = False
     kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_offset'] = 1.0
+    kwargs_dict_tt_prodmp['phase_generator_kwargs']['tau_bound'] = [0.8, 1.5]
+    kwargs_dict_tt_prodmp['phase_generator_kwargs']['delay_bound'] = [0.05, 0.15]
     kwargs_dict_tt_prodmp['phase_generator_kwargs']['learn_tau'] = True
     kwargs_dict_tt_prodmp['phase_generator_kwargs']['learn_delay'] = True
     kwargs_dict_tt_prodmp['basis_generator_kwargs']['num_basis'] = 2
     kwargs_dict_tt_prodmp['basis_generator_kwargs']['alpha'] = 25.
     kwargs_dict_tt_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3  # 3.5, 4 to try
-    # kwargs_dict_tt_prodmp['basis_generator_kwargs']['pre_compute_length_factor'] = 5
     kwargs_dict_tt_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
     kwargs_dict_tt_prodmp['black_box_kwargs']['max_planning_times'] = 3
     kwargs_dict_tt_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t: t % 50 == 0
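The three base variants and their ProMP/ProDMP versions should now all resolve through the registry. A hypothetical smoke test, assuming the usual `fancy_gym.make(env_id, seed)` entry point:

import fancy_gym

for env_id in ['TableTennis4D-v0', 'TableTennisWind-v0', 'TableTennisGoalSwitching-v0',
               'TableTennisWindProMP-v0', 'TableTennisGoalSwitchingProDMP-v0']:
    env = fancy_gym.make(env_id, seed=0)
    obs = env.reset()
    obs, reward, done, info = env.step(env.action_space.sample())
    print(env_id, 'ok')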

View File

@@ -8,4 +8,4 @@ from .hopper_throw.hopper_throw_in_basket import HopperThrowInBasketEnv
 from .reacher.reacher import ReacherEnv
 from .walker_2d_jump.walker_2d_jump import Walker2dJumpEnv
 from .box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, BoxPushingTemporalSpatialSparse
-from .table_tennis.table_tennis_env import TableTennisEnv
+from .table_tennis.table_tennis_env import TableTennisEnv, TableTennisWind, TableTennisGoalSwitching

View File

@@ -1 +1 @@
-from .mp_wrapper import MPWrapper
+from .mp_wrapper import TT_MPWrapper, TTVelObs_MPWrapper

View File

@@ -6,7 +6,7 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
 from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound

-class MPWrapper(RawInterfaceWrapper):
+class TT_MPWrapper(RawInterfaceWrapper):

     # Random x goal + random init pos
     @property
@@ -29,48 +29,26 @@ class MPWrapper(RawInterfaceWrapper):
     def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
         return self.data.qvel[:7].copy()

-    def check_time_validity(self, action):
-        return action[0] <= tau_bound[1] and action[0] >= tau_bound[0] \
-            and action[1] <= delay_bound[1] and action[1] >= delay_bound[0]
-
-    def time_invalid_traj_callback(self, action, pos_traj, vel_traj) \
-            -> Tuple[np.ndarray, float, bool, dict]:
-        tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
-        delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
-        invalid_penalty = tau_invalid_penalty + delay_invalid_penalty
-        obs = np.concatenate([self.get_obs(), np.array([0])])
-        return obs, -invalid_penalty, True, {
-            "hit_ball": [False],
-            "ball_returned_success": [False],
-            "land_dist_error": [10.],
-            "is_success": [False],
-            'trajectory_length': 1,
-            "num_steps": [1]
-        }
-
-    def episode_callback(self, action, pos_traj, vel_traj):
-        time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
-            or action[1] > delay_bound[1] or action[1] < delay_bound[0]
-        if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
-            return False
-        return True
-
-    def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray,
-                              return_contextual_obs: bool) -> Tuple[np.ndarray, float, bool, dict]:
-        tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
-        delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
-        violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
-        violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
-        invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
-                          violate_high_bound_error + violate_low_bound_error
-        obs = np.concatenate([self.get_obs(), np.array([0])])
-        if return_contextual_obs:
-            obs = self.get_obs()
-        return obs, -invalid_penalty, True, {
-            "hit_ball": [False],
-            "ball_returned_success": [False],
-            "land_dist_error": [10.],
-            "is_success": [False],
-            'trajectory_length': 1,
-            "num_steps": [1]
-        }
+    def preprocessing_and_validity_callback(self, action, pos_traj, vel_traj):
+        return self.check_traj_validity(action, pos_traj, vel_traj)
+
+    def set_episode_arguments(self, action, pos_traj, vel_traj):
+        return pos_traj, vel_traj
+
+    def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray,
+                              return_contextual_obs: bool) -> Tuple[np.ndarray, float, bool, dict]:
+        return self.get_invalid_traj_step_return(action, pos_traj, vel_traj, return_contextual_obs)
+
+
+class TTVelObs_MPWrapper(TT_MPWrapper):
+
+    @property
+    def context_mask(self):
+        return np.hstack([
+            [False] * 7,  # joints position
+            [False] * 7,  # joints velocity
+            [True] * 2,   # position ball x, y
+            [False] * 1,  # position ball z
+            [True] * 3,   # velocity ball x, y, z
+            [True] * 2,   # target landing position
+            # [True] * 1,  # time
+        ])

View File

@@ -5,6 +5,7 @@ from gym import utils, spaces
 from gym.envs.mujoco import MujocoEnv
 from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import check_init_state_validity, magnus_force
+from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound

 import mujoco
@@ -21,13 +22,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
     """
     7 DoF table tennis environment
     """

     def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
-                 enable_switching_goal: bool = False,
-                 enable_wind: bool = False,
-                 enable_artifical_wind: bool = False,
-                 enable_magnus: bool = False,
-                 enable_air: bool = False):
+                 goal_switching_step: int = None,
+                 enable_artificial_wind: bool = False):
         utils.EzPickle.__init__(**locals())
         self._steps = 0
@@ -47,12 +44,11 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
         self._ball_traj = []
         self._racket_traj = []

-        self._enable_goal_switching = enable_switching_goal
-        self._enable_artifical_wind = enable_artifical_wind
-        self._artifical_force = 0.
+        self._goal_switching_step = goal_switching_step
+
+        self._enable_artificial_wind = enable_artificial_wind
+        self._artificial_force = 0.

         MujocoEnv.__init__(self,
                            model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
@@ -62,20 +58,13 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
             self.context_bounds = CONTEXT_BOUNDS_2DIMS
         elif ctxt_dim == 4:
             self.context_bounds = CONTEXT_BOUNDS_4DIMS
-            if self._enable_goal_switching:
+            if self._goal_switching_step is not None:
                 self.context_bounds = CONTEXT_BOUNDS_SWICHING
         else:
             raise NotImplementedError

         self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32)

-        # complex dynamics settings
-        if enable_air:
-            self.model.opt.density = 1.225
-            self.model.opt.viscosity = 2.27e-5
-
-        self._enable_wind = enable_wind
-        self._enable_magnus = enable_magnus
         self._wind_vel = np.zeros(3)

     def _set_ids(self):
@@ -92,9 +81,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
         unstable_simulation = False

-        if self._enable_goal_switching:
-            if self._steps == 99 and self.np_random.uniform(0, 1) < 0.5:
+        if self._steps == self._goal_switching_step and self.np_random.uniform(0, 1) < 0.5:
             new_goal_pos = self._generate_goal_pos(random=True)
             new_goal_pos[1] = -new_goal_pos[1]
             self._goal_pos = new_goal_pos
@@ -102,8 +89,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
             mujoco.mj_forward(self.model, self.data)

         for _ in range(self.frame_skip):
-            if self._enable_artifical_wind:
-                self.data.qfrc_applied[-2] = self._artifical_force
+            if self._enable_artificial_wind:
+                self.data.qfrc_applied[-2] = self._artificial_force
             try:
                 self.do_simulation(action, 1)
             except Exception as e:
@@ -163,7 +150,6 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
     def reset_model(self):
         self._steps = 0
         self._init_ball_state = self._generate_valid_init_ball(random_pos=True, random_vel=False)
-        # self._init_ball_state[2] = 1.85
         self._goal_pos = self._generate_goal_pos(random=True)
         self.data.joint("tar_x").qpos = self._init_ball_state[0]
         self.data.joint("tar_y").qpos = self._init_ball_state[1]
@@ -172,19 +158,16 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
         self.data.joint("tar_y").qvel = self._init_ball_state[4]
         self.data.joint("tar_z").qvel = self._init_ball_state[5]

-        if self._enable_artifical_wind:
-            self._artifical_force = self.np_random.uniform(low=-0.1, high=0.1)
+        if self._enable_artificial_wind:
+            self._artificial_force = self.np_random.uniform(low=-0.1, high=0.1)

         self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]])
         self.data.qpos[:7] = np.array([0., 0., 0., 1.5, 0., 0., 1.5])
+        self.data.qvel[:7] = np.zeros(7)
         mujoco.mj_forward(self.model, self.data)

-        if self._enable_wind:
-            self._wind_vel[1] = self.np_random.uniform(low=-10, high=10, size=1)
-            self.model.opt.wind[:3] = self._wind_vel
-
         self._hit_ball = False
         self._ball_land_on_table = False
         self._ball_contact_after_hit = False
@@ -208,10 +191,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
             self.data.joint("tar_x").qpos.copy(),
             self.data.joint("tar_y").qpos.copy(),
             self.data.joint("tar_z").qpos.copy(),
-            #self.data.joint("tar_x").qvel.copy(),
-            #self.data.joint("tar_y").qvel.copy(),
-            #self.data.joint("tar_z").qvel.copy(),
-            # self.data.body("target_ball").xvel.copy(),
+            # self.data.joint("tar_x").qvel.copy(),
+            # self.data.joint("tar_y").qvel.copy(),
+            # self.data.joint("tar_z").qvel.copy(),
             self._goal_pos.copy(),
         ])
         return obs
@@ -252,56 +234,54 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
         init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
         return init_ball_state

-def plot_ball_traj(x_traj, y_traj, z_traj):
-    import matplotlib.pyplot as plt
-    fig = plt.figure()
-    ax = fig.add_subplot(111, projection='3d')
-    ax.plot(x_traj, y_traj, z_traj)
-    plt.show()
-
-def plot_ball_traj_2d(x_traj, y_traj):
-    import matplotlib.pyplot as plt
-    fig = plt.figure()
-    ax = fig.add_subplot(111)
-    ax.plot(x_traj, y_traj)
-    plt.show()
-
-def plot_compare_trajs(traj1, traj2, title):
-    import matplotlib.pyplot as plt
-    fig = plt.figure()
-    ax = fig.add_subplot(111)
-    ax.plot(traj1, color='r', label='traj1')
-    ax.plot(traj2, color='b', label='traj2')
-    ax.set_title(title)
-    plt.legend()
-    plt.show()
-
-if __name__ == "__main__":
-    env_air = TableTennisEnv(enable_air=False, enable_wind=False, enable_artifical_wind=True)
-    env_no_air = TableTennisEnv(enable_air=False, enable_wind=False)
-    for _ in range(10):
-        obs1 = env_air.reset()
-        obs2 = env_no_air.reset()
-        # obs2 = env_with_air.reset()
-        air_x_pos = []
-        no_air_x_pos = []
-        # y_pos = []
-        # z_pos = []
-        # x_vel = []
-        # y_vel = []
-        # z_vel = []
-        for _ in range(2000):
-            env_air.render("human")
-            obs1, reward1, done1, info1 = env_air.step(np.zeros(7))
-            obs2, reward2, done2, info2 = env_no_air.step(np.zeros(7))
-            # _, _, _, _ = env_no_air.step(np.zeros(7))
-            air_x_pos.append(env_air.data.joint("tar_z").qpos[0])
-            no_air_x_pos.append(env_no_air.data.joint("tar_z").qpos[0])
-            # z_pos.append(env.data.joint("tar_z").qpos[0])
-            # x_vel.append(env.data.joint("tar_x").qvel[0])
-            # y_vel.append(env.data.joint("tar_y").qvel[0])
-            # z_vel.append(env.data.joint("tar_z").qvel[0])
-            # print(reward)
-            if info1["num_steps"] == 150:
-                plot_compare_trajs(air_x_pos, no_air_x_pos, title="z_pos with/out wind")
-                break
+    def _get_traj_invalid_reward(self, action, pos_traj, vel_traj):
+        tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
+        delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
+        violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
+        violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
+        invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
+                          violate_high_bound_error + violate_low_bound_error
+        return -invalid_penalty
+
+    def get_invalid_traj_step_return(self, action, pos_traj, vel_traj, contextual_obs):
+        obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])])  # 0 for invalid traj
+        penalty = self._get_traj_invalid_reward(action, pos_traj, vel_traj)
+        return obs, penalty, True, {
+            "hit_ball": [False],
+            "ball_return_success": [False],
+            "land_dist_error": [False],
+            "trajectory_length": 1,
+            "num_steps": [1],
+        }
+
+    @staticmethod
+    def check_traj_validity(action, pos_traj, vel_traj):
+        time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
+                       or action[1] > delay_bound[1] or action[1] < delay_bound[0]
+        if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
+            return False, pos_traj, vel_traj
+        return True, pos_traj, vel_traj
+
+
+class TableTennisWind(TableTennisEnv):
+    def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4):
+        super().__init__(ctxt_dim=ctxt_dim, frame_skip=frame_skip, enable_artificial_wind=True)
+
+    def _get_obs(self):
+        obs = np.concatenate([
+            self.data.qpos.flat[:7].copy(),
+            self.data.qvel.flat[:7].copy(),
+            self.data.joint("tar_x").qpos.copy(),
+            self.data.joint("tar_y").qpos.copy(),
+            self.data.joint("tar_z").qpos.copy(),
+            self.data.joint("tar_x").qvel.copy(),
+            self.data.joint("tar_y").qvel.copy(),
+            self.data.joint("tar_z").qvel.copy(),
+            self._goal_pos.copy(),
+        ])
+        return obs
+
+
+class TableTennisGoalSwitching(TableTennisEnv):
+    def __init__(self, frame_skip: int = 4, goal_switching_step: int = 99):
+        super().__init__(frame_skip=frame_skip, goal_switching_step=goal_switching_step)
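The invalid-trajectory penalty sums the scaled tau/delay bound violations and the mean joint-bound violations. A worked, self-contained example with the bounds used in this commit (the action values are made up; joint-bound terms are zero here):

import numpy as np

tau_bound, delay_bound = [0.8, 1.5], [0.05, 0.15]
action = np.array([2.0, 0.01])  # tau = 2.0 is too long, delay = 0.01 too short

tau_pen = 3 * (max(0, action[0] - tau_bound[1]) + max(0, tau_bound[0] - action[0]))        # 3 * 0.5  = 1.5
delay_pen = 3 * (max(0, action[1] - delay_bound[1]) + max(0, delay_bound[0] - action[1]))  # 3 * 0.04 = 0.12
print(-(tau_pen + delay_pen))  # approximately -1.62, returned as the single episode reward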

View File

@@ -155,7 +155,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):

 if __name__ == '__main__':
-    render = True
+    render = False

     # DMP
     # example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
@@ -163,10 +163,15 @@ if __name__ == '__main__':
     # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
     # example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
     # example_mp("TableTennis4DProMP-v0", seed=10, iterations=20, render=render)
+    # example_mp("TableTennisWindProMP-v0", seed=10, iterations=20, render=render)
+    # example_mp("TableTennisGoalSwitchingProMP-v0", seed=10, iterations=20, render=render)

     # ProDMP
     # example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render)
-    example_mp("TableTennis4DProDMP-v0", seed=10, iterations=20, render=render)
+    example_mp("TableTennis4DProDMP-v0", seed=10, iterations=2000, render=render)
+    # example_mp("TableTennisWindProDMP-v0", seed=10, iterations=20, render=render)
+    # example_mp("TableTennisGoalSwitchingProDMP-v0", seed=10, iterations=20, render=render)

     # Altered basis functions
     # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)

View File

@@ -168,11 +168,11 @@ def make_bb(
     # set tau bounds to minimum of two env steps otherwise computing the velocity is not possible.
     # maximum is full duration of one episode.
-    if phase_kwargs.get('learn_tau'):
+    if phase_kwargs.get('learn_tau') and phase_kwargs.get('tau_bound') is None:
         phase_kwargs["tau_bound"] = [env.dt * 2, black_box_kwargs['duration']]

     # Max delay is full duration minus two steps due to above reason
-    if phase_kwargs.get('learn_delay'):
+    if phase_kwargs.get('learn_delay') and phase_kwargs.get('delay_bound') is None:
         phase_kwargs["delay_bound"] = [0, black_box_kwargs['duration'] - env.dt * 2]

     phase_gen = get_phase_generator(**phase_kwargs)
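Because the defaults now only apply when no bound is given, env-specific bounds such as the table-tennis values registered above take precedence. A self-contained sketch of that override behavior (the dt and duration values are illustrative):

dt, duration = 0.02, 2.0
phase_kwargs = {'learn_tau': True, 'learn_delay': True,
                'tau_bound': [0.8, 1.5]}  # explicitly set, e.g. via the registry kwargs

if phase_kwargs.get('learn_tau') and phase_kwargs.get('tau_bound') is None:
    phase_kwargs['tau_bound'] = [dt * 2, duration]        # skipped: bound already set
if phase_kwargs.get('learn_delay') and phase_kwargs.get('delay_bound') is None:
    phase_kwargs['delay_bound'] = [0, duration - dt * 2]  # applied: default kicks in
print(phase_kwargs['tau_bound'], phase_kwargs['delay_bound'])  # [0.8, 1.5] [0, 1.96]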