commit f376772c22
parent 28aa430fd2

    add invalid trajectory callback & invalid traj return & register all 3 variants of table tennis tasks
@@ -63,14 +63,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         self.traj_gen_action_space = self._get_traj_gen_action_space()
         self.action_space = self._get_action_space()

-        # no goal learning
-        # tricky_action_upperbound = [np.inf] * (self.traj_gen_action_space.shape[0] - 7)
-        # tricky_action_lowerbound = [-np.inf] * (self.traj_gen_action_space.shape[0] - 7)
-        # self.action_space = spaces.Box(np.array(tricky_action_lowerbound), np.array(tricky_action_upperbound), dtype=np.float32)
-        self.action_space.low[0] = 0.8
-        self.action_space.high[0] = 1.5
-        self.action_space.low[1] = 0.05
-        self.action_space.high[1] = 0.15
         self.observation_space = self._get_observation_space()

         # rendering
@@ -93,8 +85,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         return observation.astype(self.observation_space.dtype)

     def get_trajectory(self, action: np.ndarray) -> Tuple:
-        # duration = self.duration
-        duration = self.duration - self.current_traj_steps * self.dt
+        duration = self.duration
+        # duration = self.duration - self.current_traj_steps * self.dt
         if self.learn_sub_trajectories:
             duration = None
             # reset with every new call as we need to set all arguments, such as tau, delay, again.
@@ -157,8 +149,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         # TODO remove this part, right now only needed for beer pong
         # mp_params, env_spec_params, proceed = self.env.episode_callback(action, self.traj_gen)
         position, velocity = self.get_trajectory(action)
-        traj_is_valid = self.env.episode_callback(action, position, velocity)
-
+        traj_is_valid = self.env.preprocessing_and_validity_callback(action, position, velocity)
+        # insert validation here
         trajectory_length = len(position)
         rewards = np.zeros(shape=(trajectory_length,))
         if self.verbose >= 2:
@@ -169,7 +161,11 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         infos = dict()
         done = False

-        if traj_is_valid:
+        if not traj_is_valid:
+            obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity,
+                                                                                 self.return_context_observation)
+            return self.observation(obs), trajectory_return, done, infos
+        else:
             self.plan_steps += 1
             for t, (pos, vel) in enumerate(zip(position, velocity)):
                 current_pos = self.current_pos
@@ -215,10 +211,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
             infos['trajectory_length'] = t + 1
             trajectory_return = self.reward_aggregation(rewards[:t + 1])
             return self.observation(obs), trajectory_return, done, infos
-        else:
-            obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity,
-                                                                                 self.return_context_observation)
-            return self.observation(obs), trajectory_return, done, infos

     def render(self, **kwargs):
         """Only set render options here, such that they can be used during the rollout.
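Taken together, these hunks move the validity check to the front of the rollout. A condensed sketch of the resulting control flow (helper details elided; the surrounding method is the wrapper's step, as implied by the hunks above — this is not the full implementation):

    # Sketch only: condensed from the hunks above.
    def step(self, action):
        position, velocity = self.get_trajectory(action)
        traj_is_valid = self.env.preprocessing_and_validity_callback(action, position, velocity)
        if not traj_is_valid:
            # The raw env fabricates a terminal transition with a penalty reward.
            obs, trajectory_return, done, infos = self.env.invalid_traj_callback(
                action, position, velocity, self.return_context_observation)
            return self.observation(obs), trajectory_return, done, infos
        # Valid trajectory: execute it step by step and aggregate the rewards.
        self.plan_steps += 1
        ...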
@@ -52,6 +52,19 @@ class RawInterfaceWrapper(gym.Wrapper):
         """
         return self.env.dt

+    def preprocessing_and_validity_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) \
+            -> Tuple[bool, np.ndarray, np.ndarray]:
+        """
+        Used to preprocess the action and check if the desired trajectory is valid.
+        """
+        return True, pos_traj, vel_traj
+
+    def set_episode_arguments(self, action, pos_traj, vel_traj):
+        """
+        Used to set the arguments for the env that are valid for the whole episode.
+        """
+        return pos_traj, vel_traj
+
     def episode_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.array) -> Tuple[bool]:
         """
         Used to extract the parameters for the movement primitive and other parameters from an action array which might
@@ -68,7 +81,6 @@ class RawInterfaceWrapper(gym.Wrapper):

     def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
         """
-        Used to return a fake return from the environment if the desired trajectory is invalid.
+        Used to return an artificial return from the env if the desired trajectory is invalid.
         """
-        obs = np.zeros(1)
-        return obs, 0, True, {}
+        return np.zeros(1), 0, True, {}
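For a concrete sense of the new contract, here is a minimal sketch of a task-specific wrapper overriding the hooks. The class name and the bound values are illustrative, not part of the commit; the four-argument invalid_traj_callback form matches the call site in BlackBoxWrapper above:

    import numpy as np
    from typing import Tuple
    from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper

    class MyTaskMPWrapper(RawInterfaceWrapper):  # hypothetical example wrapper
        def preprocessing_and_validity_callback(self, action, pos_traj, vel_traj):
            # action[0] is the learned tau; reject durations outside an illustrative bound.
            valid = 0.5 <= action[0] <= 2.0
            return valid, pos_traj, vel_traj

        def invalid_traj_callback(self, action, pos_traj, vel_traj, return_contextual_obs) \
                -> Tuple[np.ndarray, float, bool, dict]:
            # Artificial terminal transition: dummy obs, penalty reward, done=True.
            return np.zeros(1), -1.0, True, {}

The table tennis wrapper further below follows the same pattern, but delegates the actual checks and the penalty computation to the environment.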
@@ -18,6 +18,8 @@ from .mujoco.reacher.reacher import ReacherEnv, MAX_EPISODE_STEPS_REACHER
 from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP
 from .mujoco.box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, \
     BoxPushingTemporalSpatialSparse, MAX_EPISODE_STEPS_BOX_PUSHING
+from .mujoco.table_tennis.table_tennis_env import TableTennisEnv, TableTennisWind, TableTennisGoalSwitching, \
+    MAX_EPISODE_STEPS_TABLE_TENNIS

 ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}

@@ -248,17 +250,28 @@ for ctxt_dim in [2, 4]:
     register(
         id='TableTennis{}D-v0'.format(ctxt_dim),
         entry_point='fancy_gym.envs.mujoco:TableTennisEnv',
-        max_episode_steps=350,
+        max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS,
         kwargs={
             "ctxt_dim": ctxt_dim,
             'frame_skip': 4,
-            'enable_wind': False,
-            'enable_switching_goal': False,
-            'enable_air': False,
-            'enable_artifical_wind': False,
+            'goal_switching_step': None,
+            'enable_artificial_wind': False,
         }
     )

+register(
+    id='TableTennisWind-v0',
+    entry_point='fancy_gym.envs.mujoco:TableTennisWind',
+    max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS,
+)
+
+register(
+    id='TableTennisGoalSwitching-v0',
+    entry_point='fancy_gym.envs.mujoco:TableTennisGoalSwitching',
+    max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS,
+)
+
+
 # movement Primitive Environments

 ## Simple Reacher
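With the two extra registrations, all three variants are available by id. A quick smoke test, assuming fancy_gym's usual make helper (fancy_gym.make(env_id, seed)); the loop is a plain gym rollout:

    import fancy_gym

    env = fancy_gym.make('TableTennisWind-v0', seed=1)
    obs = env.reset()
    done = False
    while not done:
        obs, reward, done, info = env.step(env.action_space.sample())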
@@ -529,17 +542,22 @@ for _v in _versions:
     ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)

## Table Tennis
-_versions = ['TableTennis2D-v0', 'TableTennis4D-v0']
+_versions = ['TableTennis2D-v0', 'TableTennis4D-v0', 'TableTennisWind-v0', 'TableTennisGoalSwitching-v0']
 for _v in _versions:
     _name = _v.split("-")
     _env_id = f'{_name[0]}ProMP-{_name[1]}'
     kwargs_dict_tt_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
-    kwargs_dict_tt_promp['wrappers'].append(mujoco.table_tennis.MPWrapper)
+    if _v == 'TableTennisWind-v0':
+        kwargs_dict_tt_promp['wrappers'].append(mujoco.table_tennis.TTVelObs_MPWrapper)
+    else:
+        kwargs_dict_tt_promp['wrappers'].append(mujoco.table_tennis.TT_MPWrapper)
     kwargs_dict_tt_promp['name'] = _v
     kwargs_dict_tt_promp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0])
     kwargs_dict_tt_promp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
     kwargs_dict_tt_promp['phase_generator_kwargs']['learn_tau'] = True
     kwargs_dict_tt_promp['phase_generator_kwargs']['learn_delay'] = True
+    kwargs_dict_tt_promp['phase_generator_kwargs']['tau_bound'] = [0.8, 1.5]
+    kwargs_dict_tt_promp['phase_generator_kwargs']['delay_bound'] = [0.05, 0.15]
     kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis'] = 3
     kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2
     kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_goal'] = 1
@@ -556,7 +574,10 @@ for _v in _versions:
     _name = _v.split("-")
     _env_id = f'{_name[0]}ProDMP-{_name[1]}'
     kwargs_dict_tt_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
-    kwargs_dict_tt_prodmp['wrappers'].append(mujoco.table_tennis.MPWrapper)
+    if _v == 'TableTennisWind-v0':
+        kwargs_dict_tt_prodmp['wrappers'].append(mujoco.table_tennis.TTVelObs_MPWrapper)
+    else:
+        kwargs_dict_tt_prodmp['wrappers'].append(mujoco.table_tennis.TT_MPWrapper)
     kwargs_dict_tt_prodmp['name'] = _v
     kwargs_dict_tt_prodmp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0])
     kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
@@ -564,12 +585,13 @@ for _v in _versions:
     kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_scale'] = 1.0
     kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['auto_scale_basis'] = False
     kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_offset'] = 1.0
+    kwargs_dict_tt_prodmp['phase_generator_kwargs']['tau_bound'] = [0.8, 1.5]
+    kwargs_dict_tt_prodmp['phase_generator_kwargs']['delay_bound'] = [0.05, 0.15]
     kwargs_dict_tt_prodmp['phase_generator_kwargs']['learn_tau'] = True
     kwargs_dict_tt_prodmp['phase_generator_kwargs']['learn_delay'] = True
     kwargs_dict_tt_prodmp['basis_generator_kwargs']['num_basis'] = 2
     kwargs_dict_tt_prodmp['basis_generator_kwargs']['alpha'] = 25.
     kwargs_dict_tt_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3  # 3.5, 4 to try
-    #kwargs_dict_tt_prodmp['basis_generator_kwargs']['pre_compute_length_factor'] = 5
     kwargs_dict_tt_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
     kwargs_dict_tt_prodmp['black_box_kwargs']['max_planning_times'] = 3
     kwargs_dict_tt_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 50 == 0
@@ -8,4 +8,4 @@ from .hopper_throw.hopper_throw_in_basket import HopperThrowInBasketEnv
 from .reacher.reacher import ReacherEnv
 from .walker_2d_jump.walker_2d_jump import Walker2dJumpEnv
 from .box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, BoxPushingTemporalSpatialSparse
-from .table_tennis.table_tennis_env import TableTennisEnv
+from .table_tennis.table_tennis_env import TableTennisEnv, TableTennisWind, TableTennisGoalSwitching
@@ -1 +1 @@
-from .mp_wrapper import MPWrapper
+from .mp_wrapper import TT_MPWrapper, TTVelObs_MPWrapper
@@ -6,7 +6,7 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
 from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound


-class MPWrapper(RawInterfaceWrapper):
+class TT_MPWrapper(RawInterfaceWrapper):

     # Random x goal + random init pos
     @property
@@ -29,48 +29,26 @@ class MPWrapper(RawInterfaceWrapper):
     def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
         return self.data.qvel[:7].copy()

-    def check_time_validity(self, action):
-        return action[0] <= tau_bound[1] and action[0] >= tau_bound[0] \
-            and action[1] <= delay_bound[1] and action[1] >= delay_bound[0]
+    def preprocessing_and_validity_callback(self, action, pos_traj, vel_traj):
+        return self.check_traj_validity(action, pos_traj, vel_traj)

-    def time_invalid_traj_callback(self, action, pos_traj, vel_traj) \
-            -> Tuple[np.ndarray, float, bool, dict]:
-        tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
-        delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
-        invalid_penalty = tau_invalid_penalty + delay_invalid_penalty
-        obs = np.concatenate([self.get_obs(), np.array([0])])
-        return obs, -invalid_penalty, True, {
-            "hit_ball": [False],
-            "ball_returned_success": [False],
-            "land_dist_error": [10.],
-            "is_success": [False],
-            'trajectory_length': 1,
-            "num_steps": [1]
-        }
+    def set_episode_arguments(self, action, pos_traj, vel_traj):
+        return pos_traj, vel_traj

-    def episode_callback(self, action, pos_traj, vel_traj):
-        time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
-            or action[1] > delay_bound[1] or action[1] < delay_bound[0]
-        if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
-            return False
-        return True
-
-    def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray,
+    def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray,
                               return_contextual_obs: bool) -> Tuple[np.ndarray, float, bool, dict]:
-        tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
-        delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
-        violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
-        violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
-        invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
-            violate_high_bound_error + violate_low_bound_error
-        obs = np.concatenate([self.get_obs(), np.array([0])])
-        if return_contextual_obs:
-            obs = self.get_obs()
-        return obs, -invalid_penalty, True, {
-            "hit_ball": [False],
-            "ball_returned_success": [False],
-            "land_dist_error": [10.],
-            "is_success": [False],
-            'trajectory_length': 1,
-            "num_steps": [1]
-        }
+        return self.get_invalid_traj_step_return(action, pos_traj, vel_traj, return_contextual_obs)
+
+class TTVelObs_MPWrapper(TT_MPWrapper):
+
+    @property
+    def context_mask(self):
+        return np.hstack([
+            [False] * 7,  # joints position
+            [False] * 7,  # joints velocity
+            [True] * 2,  # position ball x, y
+            [False] * 1,  # position ball z
+            [True] * 3,  # velocity ball x, y, z
+            [True] * 2,  # target landing position
+            # [True] * 1,  # time
+        ])
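The velocity-observation variant only changes which entries of the (larger) TableTennisWind observation are treated as context; the mask itself is consumed by the black-box wrapper when it filters observations. A self-contained check of the intended dimensions (the obs layout is taken from TableTennisWind._get_obs further below):

    import numpy as np

    obs = np.arange(22.0)  # TableTennisWind obs: 7 qpos + 7 qvel + 3 ball pos + 3 ball vel + 2 goal
    mask = np.hstack([[False] * 7, [False] * 7, [True] * 2, [False] * 1, [True] * 3, [True] * 2])
    context = obs[mask]    # ball x/y position, ball velocity, target landing position
    assert context.shape == (7,)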
@@ -5,6 +5,7 @@ from gym import utils, spaces
 from gym.envs.mujoco import MujocoEnv

 from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import check_init_state_validity, magnus_force
+from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound

 import mujoco

@@ -21,13 +22,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
     """
     7 DoF table tennis environment
     """

     def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
-                 enable_switching_goal: bool = False,
-                 enable_wind: bool = False,
-                 enable_artifical_wind: bool = False,
-                 enable_magnus: bool = False,
-                 enable_air: bool = False):
+                 goal_switching_step: int = None,
+                 enable_artificial_wind: bool = False):
         utils.EzPickle.__init__(**locals())
         self._steps = 0

@@ -47,12 +44,11 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
         self._ball_traj = []
         self._racket_traj = []

-        self._enable_goal_switching = enable_switching_goal
+        self._goal_switching_step = goal_switching_step

-        self._enable_artifical_wind = enable_artifical_wind
+        self._enable_artificial_wind = enable_artificial_wind

-        self._artifical_force = 0.
-
+        self._artificial_force = 0.

         MujocoEnv.__init__(self,
                            model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
@@ -62,20 +58,13 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
             self.context_bounds = CONTEXT_BOUNDS_2DIMS
         elif ctxt_dim == 4:
             self.context_bounds = CONTEXT_BOUNDS_4DIMS
-            if self._enable_goal_switching:
+            if self._goal_switching_step is not None:
                 self.context_bounds = CONTEXT_BOUNDS_SWICHING
         else:
             raise NotImplementedError

         self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32)

-        # complex dynamics settings
-        if enable_air:
-            self.model.opt.density = 1.225
-            self.model.opt.viscosity = 2.27e-5
-
-        self._enable_wind = enable_wind
-        self._enable_magnus = enable_magnus
         self._wind_vel = np.zeros(3)

     def _set_ids(self):
@@ -92,9 +81,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):

         unstable_simulation = False

-        if self._enable_goal_switching:
-            if self._steps == 99 and self.np_random.uniform(0, 1) < 0.5:
-                new_goal_pos = self._generate_goal_pos(random=True)
-                new_goal_pos[1] = -new_goal_pos[1]
-                self._goal_pos = new_goal_pos
+        if self._steps == self._goal_switching_step and self.np_random.uniform(0, 1) < 0.5:
+            new_goal_pos = self._generate_goal_pos(random=True)
+            new_goal_pos[1] = -new_goal_pos[1]
+            self._goal_pos = new_goal_pos
@@ -102,8 +89,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
         mujoco.mj_forward(self.model, self.data)

         for _ in range(self.frame_skip):
-            if self._enable_artifical_wind:
-                self.data.qfrc_applied[-2] = self._artifical_force
+            if self._enable_artificial_wind:
+                self.data.qfrc_applied[-2] = self._artificial_force
             try:
                 self.do_simulation(action, 1)
             except Exception as e:
@@ -163,7 +150,6 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
     def reset_model(self):
         self._steps = 0
         self._init_ball_state = self._generate_valid_init_ball(random_pos=True, random_vel=False)
-        # self._init_ball_state[2] = 1.85
         self._goal_pos = self._generate_goal_pos(random=True)
         self.data.joint("tar_x").qpos = self._init_ball_state[0]
         self.data.joint("tar_y").qpos = self._init_ball_state[1]
@@ -172,19 +158,16 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
         self.data.joint("tar_y").qvel = self._init_ball_state[4]
         self.data.joint("tar_z").qvel = self._init_ball_state[5]

-        if self._enable_artifical_wind:
-            self._artifical_force = self.np_random.uniform(low=-0.1, high=0.1)
+        if self._enable_artificial_wind:
+            self._artificial_force = self.np_random.uniform(low=-0.1, high=0.1)

         self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]])

         self.data.qpos[:7] = np.array([0., 0., 0., 1.5, 0., 0., 1.5])
+        self.data.qvel[:7] = np.zeros(7)

         mujoco.mj_forward(self.model, self.data)

-        if self._enable_wind:
-            self._wind_vel[1] = self.np_random.uniform(low=-10, high=10, size=1)
-            self.model.opt.wind[:3] = self._wind_vel
-
         self._hit_ball = False
         self._ball_land_on_table = False
         self._ball_contact_after_hit = False
@@ -208,10 +191,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
             self.data.joint("tar_x").qpos.copy(),
             self.data.joint("tar_y").qpos.copy(),
             self.data.joint("tar_z").qpos.copy(),
-            #self.data.joint("tar_x").qvel.copy(),
-            #self.data.joint("tar_y").qvel.copy(),
-            #self.data.joint("tar_z").qvel.copy(),
-            # self.data.body("target_ball").xvel.copy(),
+            # self.data.joint("tar_x").qvel.copy(),
+            # self.data.joint("tar_y").qvel.copy(),
+            # self.data.joint("tar_z").qvel.copy(),
             self._goal_pos.copy(),
         ])
         return obs
@@ -252,56 +234,54 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
         init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
         return init_ball_state

-def plot_ball_traj(x_traj, y_traj, z_traj):
-    import matplotlib.pyplot as plt
-    fig = plt.figure()
-    ax = fig.add_subplot(111, projection='3d')
-    ax.plot(x_traj, y_traj, z_traj)
-    plt.show()
-
-def plot_ball_traj_2d(x_traj, y_traj):
-    import matplotlib.pyplot as plt
-    fig = plt.figure()
-    ax = fig.add_subplot(111)
-    ax.plot(x_traj, y_traj)
-    plt.show()
-
-def plot_compare_trajs(traj1, traj2, title):
-    import matplotlib.pyplot as plt
-    fig = plt.figure()
-    ax = fig.add_subplot(111)
-    ax.plot(traj1, color='r', label='traj1')
-    ax.plot(traj2, color='b', label='traj2')
-    ax.set_title(title)
-    plt.legend()
-    plt.show()
-
-if __name__ == "__main__":
-    env_air = TableTennisEnv(enable_air=False, enable_wind=False, enable_artifical_wind=True)
-    env_no_air = TableTennisEnv(enable_air=False, enable_wind=False)
-    for _ in range(10):
-        obs1 = env_air.reset()
-        obs2 = env_no_air.reset()
-        # obs2 = env_with_air.reset()
-        air_x_pos = []
-        no_air_x_pos = []
-        # y_pos = []
-        # z_pos = []
-        # x_vel = []
-        # y_vel = []
-        # z_vel = []
-        for _ in range(2000):
-            env_air.render("human")
-            obs1, reward1, done1, info1 = env_air.step(np.zeros(7))
-            obs2, reward2, done2, info2 = env_no_air.step(np.zeros(7))
-            # # _, _, _, _ = env_no_air.step(np.zeros(7))
-            air_x_pos.append(env_air.data.joint("tar_z").qpos[0])
-            no_air_x_pos.append(env_no_air.data.joint("tar_z").qpos[0])
-            # # z_pos.append(env.data.joint("tar_z").qpos[0])
-            # # x_vel.append(env.data.joint("tar_x").qvel[0])
-            # # y_vel.append(env.data.joint("tar_y").qvel[0])
-            # # z_vel.append(env.data.joint("tar_z").qvel[0])
-            # # print(reward)
-            if info1["num_steps"] == 150:
-                plot_compare_trajs(air_x_pos, no_air_x_pos, title="z_pos with/out wind")
-                break
+    def _get_traj_invalid_reward(self, action, pos_traj, vel_traj):
+        tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
+        delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
+        violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
+        violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
+        invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
+            violate_high_bound_error + violate_low_bound_error
+        return -invalid_penalty
+
+    def get_invalid_traj_step_return(self, action, pos_traj, vel_traj, contextual_obs):
+        obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])])  # 0 for invalid traj
+        penalty = self._get_traj_invalid_reward(action, pos_traj, vel_traj)
+        return obs, penalty, True, {
+            "hit_ball": [False],
+            "ball_return_success": [False],
+            "land_dist_error": [False],
+            "trajectory_length": 1,
+            "num_steps": [1],
+        }
+
+    @staticmethod
+    def check_traj_validity(action, pos_traj, vel_traj):
+        time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
+            or action[1] > delay_bound[1] or action[1] < delay_bound[0]
+        if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
+            return False, pos_traj, vel_traj
+        return True, pos_traj, vel_traj
+
+
+class TableTennisWind(TableTennisEnv):
+    def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4):
+        super().__init__(ctxt_dim=ctxt_dim, frame_skip=frame_skip, enable_artificial_wind=True)
+
+    def _get_obs(self):
+        obs = np.concatenate([
+            self.data.qpos.flat[:7].copy(),
+            self.data.qvel.flat[:7].copy(),
+            self.data.joint("tar_x").qpos.copy(),
+            self.data.joint("tar_y").qpos.copy(),
+            self.data.joint("tar_z").qpos.copy(),
+            self.data.joint("tar_x").qvel.copy(),
+            self.data.joint("tar_y").qvel.copy(),
+            self.data.joint("tar_z").qvel.copy(),
+            self._goal_pos.copy(),
+        ])
+        return obs
+
+
+class TableTennisGoalSwitching(TableTennisEnv):
+    def __init__(self, frame_skip: int = 4, goal_switching_step: int = 99):
+        super().__init__(frame_skip=frame_skip, goal_switching_step=goal_switching_step)
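The two subclasses are thin wrappers that pin constructor flags of TableTennisEnv. A sketch of constructing them directly, assuming a working MuJoCo setup (the printed shape follows from the _get_obs override above):

    env_wind = TableTennisWind()             # fixes enable_artificial_wind=True, extends the observation
    env_switch = TableTennisGoalSwitching()  # goal may flip sides at step 99 with probability 0.5
    obs = env_wind.reset()
    print(obs.shape)  # expected (22,): 7 qpos + 7 qvel + ball position and velocity + goal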
@@ -155,7 +155,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):


 if __name__ == '__main__':
-    render = True
+    render = False
     # DMP
     # example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)

@@ -163,10 +163,15 @@ if __name__ == '__main__':
     # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
     # example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
     # example_mp("TableTennis4DProMP-v0", seed=10, iterations=20, render=render)
+    # example_mp("TableTennisWindProMP-v0", seed=10, iterations=20, render=render)
+    # example_mp("TableTennisGoalSwitchingProMP-v0", seed=10, iterations=20, render=render)

     # ProDMP
     # example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render)
-    example_mp("TableTennis4DProDMP-v0", seed=10, iterations=20, render=render)
+    example_mp("TableTennis4DProDMP-v0", seed=10, iterations=2000, render=render)
+    # example_mp("TableTennisWindProDMP-v0", seed=10, iterations=20, render=render)
+    # example_mp("TableTennisGoalSwitchingProDMP-v0", seed=10, iterations=20, render=render)

     # Altered basis functions
     # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)

@@ -168,11 +168,11 @@ def make_bb(

     # set tau bounds to minimum of two env steps otherwise computing the velocity is not possible.
     # maximum is full duration of one episode.
-    if phase_kwargs.get('learn_tau'):
+    if phase_kwargs.get('learn_tau') and phase_kwargs.get('tau_bound') is None:
         phase_kwargs["tau_bound"] = [env.dt * 2, black_box_kwargs['duration']]

     # Max delay is full duration minus two steps due to above reason
-    if phase_kwargs.get('learn_delay'):
+    if phase_kwargs.get('learn_delay') and phase_kwargs.get('delay_bound') is None:
         phase_kwargs["delay_bound"] = [0, black_box_kwargs['duration'] - env.dt * 2]

     phase_gen = get_phase_generator(**phase_kwargs)
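This guard means explicitly configured bounds, such as the tau_bound/delay_bound set in the table tennis dictionaries above, are no longer overwritten by the duration-derived defaults. A small self-contained sketch of the behavior (the numeric defaults stand in for env.dt * 2 and the episode duration):

    phase_kwargs = {'learn_tau': True, 'tau_bound': [0.8, 1.5]}
    if phase_kwargs.get('learn_tau') and phase_kwargs.get('tau_bound') is None:
        phase_kwargs['tau_bound'] = [0.02, 2.0]  # illustrative defaults
    assert phase_kwargs['tau_bound'] == [0.8, 1.5]  # user-supplied bound wins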