check time validity before pos validity

Hongyi Zhou 2022-11-08 13:57:32 +01:00
parent d384e6e764
commit b9c2348855
3 changed files with 109 additions and 83 deletions


@@ -164,95 +164,101 @@ class BlackBoxWrapper(gym.ObservationWrapper):
     def step(self, action: np.ndarray):
         """ This function generates a trajectory based on a MP and then does the usual loop over reset and step"""

-        if self.plan_counts == 0:
-            self.tau_first_prediction = action[0]
-
-        ## tricky part, only use weights basis
-        # basis_weights = action.reshape(7, -1)
-        # goal_weights = np.zeros((7, 1))
-        # action = np.concatenate((basis_weights, goal_weights), axis=1).flatten()
-
-        # TODO remove this part, right now only needed for beer pong
-        # mp_params, env_spec_params, proceed = self.env.episode_callback(action, self.traj_gen)
-        position, velocity = self.get_trajectory(action)
-        traj_is_valid = self.env.episode_callback(action, position, velocity)
-
-        trajectory_length = len(position)
-        rewards = np.zeros(shape=(trajectory_length,))
-        if self.verbose >= 2:
-            actions = np.zeros(shape=(trajectory_length,) + self.env.action_space.shape)
-            observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape,
-                                    dtype=self.env.observation_space.dtype)
-
-        infos = dict()
-        done = False
-
-        if self.verbose >= 2:
-            desired_pos_traj = []
-            desired_vel_traj = []
-            pos_traj = []
-            vel_traj = []
-
-        if traj_is_valid:
-            self.plan_counts += 1
-            for t, (pos, vel) in enumerate(zip(position, velocity)):
-                step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
-                c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
-                obs, c_reward, done, info = self.env.step(c_action)
-                rewards[t] = c_reward
-
-                if self.verbose >= 2:
-                    actions[t, :] = c_action
-                    observations[t, :] = obs
-
-                for k, v in info.items():
-                    elems = infos.get(k, [None] * trajectory_length)
-                    elems[t] = v
-                    infos[k] = elems
-
-                if self.verbose >= 2:
-                    desired_pos_traj.append(pos)
-                    desired_vel_traj.append(vel)
-                    pos_traj.append(self.current_pos)
-                    vel_traj.append(self.current_vel)
-
-                if self.render_kwargs:
-                    self.env.render(**self.render_kwargs)
-
-                if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
-                                                    t + 1 + self.current_traj_steps):
-
-                    # if self.max_planning_times is not None and self.plan_counts >= self.max_planning_times:
-                    #     continue
-
-                    self.condition_pos = pos if self.desired_conditioning else self.current_pos
-                    self.condition_vel = vel if self.desired_conditioning else self.current_vel
-
-                    break
-
-            infos.update({k: v[:t+1] for k, v in infos.items()})
-            self.current_traj_steps += t + 1
-
-            if self.verbose >= 2:
-                infos['desired_pos'] = position[:t+1]
-                infos['desired_vel'] = velocity[:t+1]
-                infos['current_pos'] = self.current_pos
-                infos['current_vel'] = self.current_vel
-                infos['step_actions'] = actions[:t + 1]
-                infos['step_observations'] = observations[:t + 1]
-                infos['step_rewards'] = rewards[:t + 1]
-                infos['desired_pos_traj'] = np.array(desired_pos_traj)
-                infos['desired_vel_traj'] = np.array(desired_vel_traj)
-                infos['pos_traj'] = np.array(pos_traj)
-                infos['vel_traj'] = np.array(vel_traj)
-
-            infos['trajectory_length'] = t + 1
-            trajectory_return = self.reward_aggregation(rewards[:t + 1])
-            return self.observation(obs), trajectory_return, done, infos
-        else:
-            obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity)
-            return self.observation(obs), trajectory_return, done, infos
+        time_valid = self.env.check_time_validity(action)
+
+        if time_valid:
+            if self.plan_counts == 0:
+                self.tau_first_prediction = action[0]
+
+            ## tricky part, only use weights basis
+            # basis_weights = action.reshape(7, -1)
+            # goal_weights = np.zeros((7, 1))
+            # action = np.concatenate((basis_weights, goal_weights), axis=1).flatten()
+
+            # TODO remove this part, right now only needed for beer pong
+            # mp_params, env_spec_params, proceed = self.env.episode_callback(action, self.traj_gen)
+            position, velocity = self.get_trajectory(action)
+            traj_is_valid = self.env.episode_callback(action, position, velocity)
+
+            trajectory_length = len(position)
+            rewards = np.zeros(shape=(trajectory_length,))
+            if self.verbose >= 2:
+                actions = np.zeros(shape=(trajectory_length,) + self.env.action_space.shape)
+                observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape,
+                                        dtype=self.env.observation_space.dtype)
+
+            infos = dict()
+            done = False
+
+            if self.verbose >= 2:
+                desired_pos_traj = []
+                desired_vel_traj = []
+                pos_traj = []
+                vel_traj = []
+
+            if traj_is_valid:
+                self.plan_counts += 1
+                for t, (pos, vel) in enumerate(zip(position, velocity)):
+                    step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
+                    c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
+                    obs, c_reward, done, info = self.env.step(c_action)
+                    rewards[t] = c_reward
+
+                    if self.verbose >= 2:
+                        actions[t, :] = c_action
+                        observations[t, :] = obs
+
+                    for k, v in info.items():
+                        elems = infos.get(k, [None] * trajectory_length)
+                        elems[t] = v
+                        infos[k] = elems
+
+                    if self.verbose >= 2:
+                        desired_pos_traj.append(pos)
+                        desired_vel_traj.append(vel)
+                        pos_traj.append(self.current_pos)
+                        vel_traj.append(self.current_vel)
+
+                    if self.render_kwargs:
+                        self.env.render(**self.render_kwargs)
+
+                    if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
+                                                        t + 1 + self.current_traj_steps):
+
+                        # if self.max_planning_times is not None and self.plan_counts >= self.max_planning_times:
+                        #     continue
+
+                        self.condition_pos = pos if self.desired_conditioning else self.current_pos
+                        self.condition_vel = vel if self.desired_conditioning else self.current_vel
+
+                        break
+
+                infos.update({k: v[:t+1] for k, v in infos.items()})
+                self.current_traj_steps += t + 1
+
+                if self.verbose >= 2:
+                    infos['desired_pos'] = position[:t+1]
+                    infos['desired_vel'] = velocity[:t+1]
+                    infos['current_pos'] = self.current_pos
+                    infos['current_vel'] = self.current_vel
+                    infos['step_actions'] = actions[:t + 1]
+                    infos['step_observations'] = observations[:t + 1]
+                    infos['step_rewards'] = rewards[:t + 1]
+                    infos['desired_pos_traj'] = np.array(desired_pos_traj)
+                    infos['desired_vel_traj'] = np.array(desired_vel_traj)
+                    infos['pos_traj'] = np.array(pos_traj)
+                    infos['vel_traj'] = np.array(vel_traj)
+
+                infos['trajectory_length'] = t + 1
+                trajectory_return = self.reward_aggregation(rewards[:t + 1])
+                return self.observation(obs), trajectory_return, done, infos
+            else:
+                obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity)
+                return self.observation(obs), trajectory_return, done, infos
+        else:
+            obs, trajectory_return, done, infos = self.env.time_invalid_traj_callback(action)
+            return self.observation(obs), trajectory_return, done, infos

     def render(self, **kwargs):
         """Only set render options here, such that they can be used during the rollout.
         This only needs to be called once"""
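
The hunk above reorders the validity checks: the time parameters of the action (tau and delay) are screened before any trajectory is generated, and only a time-valid action reaches the position check. A minimal standalone sketch of that dispatch order follows; DummyEnv and its bounds are illustrative assumptions, not values from this commit.

import numpy as np

class DummyEnv:
    """Stand-in for the wrapped env; bounds here are assumptions for illustration."""
    def check_time_validity(self, action):
        return 0.0 <= action[0] <= 1.5 and 0.0 <= action[1] <= 0.3

    def episode_callback(self, action, pos_traj, vel_traj):
        # stand-in position-validity check
        return bool(np.all(np.abs(pos_traj) <= np.pi))

def dispatch(env, action, pos_traj, vel_traj):
    """Mirrors the new step(): time validity first, position validity second."""
    if not env.check_time_validity(action):
        return 'time_invalid_traj_callback'
    if not env.episode_callback(action, pos_traj, vel_traj):
        return 'invalid_traj_callback'
    return 'rollout'

print(dispatch(DummyEnv(), np.array([2.0, 0.1]), np.zeros((10, 7)), np.zeros((10, 7))))
# -> time_invalid_traj_callback (tau = 2.0 exceeds the assumed upper bound)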


@@ -560,7 +560,7 @@ for _v in _versions:
     _name = _v.split("-")
     _env_id = f'{_name[0]}ProDMP-{_name[1]}'
     kwargs_dict_tt_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
-    kwargs_dict_tt_prodmp['wrappers'].append(mujoco.box_pushing.MPWrapper)
+    kwargs_dict_tt_prodmp['wrappers'].append(mujoco.table_tennis.MPWrapper)
     kwargs_dict_tt_prodmp['name'] = _v
     kwargs_dict_tt_prodmp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0])
     kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
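
For reference, a quick illustrative check of the env-id pattern this registration loop builds; the version string below is a hypothetical example, not taken from the registration list.

_v = "TableTennis4D-v0"   # hypothetical entry of _versions
_name = _v.split("-")
_env_id = f'{_name[0]}ProDMP-{_name[1]}'
print(_env_id)            # TableTennis4DProDMP-v0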


@@ -28,12 +28,31 @@ class MPWrapper(RawInterfaceWrapper):
     def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
         return self.data.qvel[:7].copy()

+    def check_time_validity(self, action):
+        return action[0] <= tau_bound[1] and action[0] >= tau_bound[0] \
+               and action[1] <= delay_bound[1] and action[1] >= delay_bound[0]
+
+    def time_invalid_traj_callback(self, action) \
+            -> Tuple[np.ndarray, float, bool, dict]:
+        tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
+        delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
+        invalid_penalty = tau_invalid_penalty + delay_invalid_penalty
+        obs = np.concatenate([self.get_obs(), np.array([0])])
+        return obs, -invalid_penalty, True, {
+            "hit_ball": [False],
+            "ball_returned_success": [False],
+            "land_dist_error": [10.],
+            "is_success": [False],
+            'trajectory_length': 1,
+            "num_steps": [1]
+        }
+
     def episode_callback(self, action, pos_traj, vel_traj):
         time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
                        or action[1] > delay_bound[1] or action[1] < delay_bound[0]
         if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
-            return False
-        return True
+            return True
+        return False

     def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \
             -> Tuple[np.ndarray, float, bool, dict]:
@@ -43,7 +62,8 @@ class MPWrapper(RawInterfaceWrapper):
         violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
         invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
                           violate_high_bound_error + violate_low_bound_error
-        return self.get_obs(), -invalid_penalty, True, {
+        obs = np.concatenate([self.get_obs(), np.array([0])])
+        return obs, -invalid_penalty, True, {
             "hit_ball": [False],
             "ball_returned_success": [False],
             "land_dist_error": [10.],