table tennis 4D replanning works

Hongyi Zhou 2022-11-09 10:42:36 +01:00
parent b9c2348855
commit 99a514026f
5 changed files with 89 additions and 92 deletions

View File

@@ -70,9 +70,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
        # tricky_action_upperbound = [np.inf] * (self.traj_gen_action_space.shape[0] - 7)
        # tricky_action_lowerbound = [-np.inf] * (self.traj_gen_action_space.shape[0] - 7)
        # self.action_space = spaces.Box(np.array(tricky_action_lowerbound), np.array(tricky_action_upperbound), dtype=np.float32)
-       self.action_space.low[0] = 0.5
+       self.action_space.low[0] = 0.8
        self.action_space.high[0] = 1.5
-       self.action_space.low[1] = 0.02
+       self.action_space.low[1] = 0.05
        self.action_space.high[1] = 0.15
        self.observation_space = self._get_observation_space()
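Here the first two entries of the learned action are the timing parameters these bounds constrain: action[0] is tau and action[1] the delay (the MPWrapper's time-validity check further down reads them the same way), while the remaining entries are the basis weights. A minimal sketch of that layout; the weight count is purely illustrative:

import numpy as np

# action[0] = tau in [0.8, 1.5], action[1] = delay in [0.05, 0.15];
# the remaining entries are the ProDMP basis weights (21 here is an assumption).
action = np.concatenate(([1.0, 0.1], np.zeros(21)))
tau, delay, weights = action[0], action[1], action[2:]
print(tau, delay, weights.shape)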
@@ -99,8 +99,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
    def get_trajectory(self, action: np.ndarray) -> Tuple:
        # duration = self.duration
-       # duration = self.duration - self.current_traj_steps * self.dt
-       duration = 2.
+       duration = self.duration - self.current_traj_steps * self.dt
        if self.learn_sub_trajectories:
            duration = None
        # reset with every new call as we need to set all arguments, such as tau, delay, again.
@@ -122,7 +121,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
        self.condition_vel = torch.as_tensor(self.condition_vel, dtype=torch.float32)
        self.traj_gen.set_boundary_conditions(bc_time, self.condition_pos, self.condition_vel)
        # self.traj_gen.set_duration(duration, self.dt)
-       self.traj_gen.set_duration(self.tau_first_prediction, self.dt)
+       self.traj_gen.set_duration(duration, self.dt)
        # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
        position = get_numpy(self.traj_gen.get_traj_pos())
        velocity = get_numpy(self.traj_gen.get_traj_vel())
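With set_duration now fed the remaining time rather than a fixed value, each replanning call only generates the part of the trajectory that is still ahead of the current step. A standalone sketch of that bookkeeping, with illustrative dt, horizon, and segment length:

dt = 0.02                # control time step (illustrative)
total_duration = 2.0     # full trajectory horizon in seconds (matches the ProDMP default below)
current_traj_steps = 0   # env steps already executed in this episode

for plan in range(4):
    duration = total_duration - current_traj_steps * dt   # time left to cover
    if duration <= 0:
        break
    print(f"plan {plan}: generate {duration:.2f}s of trajectory")
    current_traj_steps += 25                               # steps until the next replanning point (illustrative)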
@@ -164,101 +163,98 @@ class BlackBoxWrapper(gym.ObservationWrapper):
    def step(self, action: np.ndarray):
        """ This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
-       time_valid = self.env.check_time_validity(action)
-
-       if time_valid:
-           if self.plan_counts == 0:
-               self.tau_first_prediction = action[0]
+       # time_valid = self.env.check_time_validity(action)
+       #
+       # if time_valid:
        ## tricky part, only use weights basis
        # basis_weights = action.reshape(7, -1)
        # goal_weights = np.zeros((7, 1))
        # action = np.concatenate((basis_weights, goal_weights), axis=1).flatten()

        # TODO remove this part, right now only needed for beer pong
        # mp_params, env_spec_params, proceed = self.env.episode_callback(action, self.traj_gen)
        position, velocity = self.get_trajectory(action)
        traj_is_valid = self.env.episode_callback(action, position, velocity)

        trajectory_length = len(position)
        rewards = np.zeros(shape=(trajectory_length,))
        if self.verbose >= 2:
            actions = np.zeros(shape=(trajectory_length,) + self.env.action_space.shape)
            observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape,
                                    dtype=self.env.observation_space.dtype)

        infos = dict()
        done = False

        if self.verbose >= 2:
            desired_pos_traj = []
            desired_vel_traj = []
            pos_traj = []
            vel_traj = []

        if traj_is_valid:
            self.plan_counts += 1
            for t, (pos, vel) in enumerate(zip(position, velocity)):
                step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
                c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
                obs, c_reward, done, info = self.env.step(c_action)
                rewards[t] = c_reward

                if self.verbose >= 2:
                    actions[t, :] = c_action
                    observations[t, :] = obs

                for k, v in info.items():
                    elems = infos.get(k, [None] * trajectory_length)
                    elems[t] = v
                    infos[k] = elems

                if self.verbose >= 2:
                    desired_pos_traj.append(pos)
                    desired_vel_traj.append(vel)
                    pos_traj.append(self.current_pos)
                    vel_traj.append(self.current_vel)

                if self.render_kwargs:
                    self.env.render(**self.render_kwargs)

                if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
                                                    t + 1 + self.current_traj_steps):
                    # if self.max_planning_times is not None and self.plan_counts >= self.max_planning_times:
                    #     continue
                    self.condition_pos = pos if self.desired_conditioning else self.current_pos
                    self.condition_vel = vel if self.desired_conditioning else self.current_vel
                    break

            infos.update({k: v[:t+1] for k, v in infos.items()})
            self.current_traj_steps += t + 1

            if self.verbose >= 2:
                infos['desired_pos'] = position[:t+1]
                infos['desired_vel'] = velocity[:t+1]
                infos['current_pos'] = self.current_pos
                infos['current_vel'] = self.current_vel
                infos['step_actions'] = actions[:t + 1]
                infos['step_observations'] = observations[:t + 1]
                infos['step_rewards'] = rewards[:t + 1]
                infos['desired_pos_traj'] = np.array(desired_pos_traj)
                infos['desired_vel_traj'] = np.array(desired_vel_traj)
                infos['pos_traj'] = np.array(pos_traj)
                infos['vel_traj'] = np.array(vel_traj)

            infos['trajectory_length'] = t + 1
            trajectory_return = self.reward_aggregation(rewards[:t + 1])
            return self.observation(obs), trajectory_return, done, infos
        else:
            obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity)
            return self.observation(obs), trajectory_return, done, infos
-       else:
-           obs, trajectory_return, done, infos = self.env.time_invalid_traj_callback(action)
-           return self.observation(obs), trajectory_return, done, infos
+       # else:
+       #     obs, trajectory_return, done, infos = self.env.time_invalid_traj_callback(action)
+       #     return self.observation(obs), trajectory_return, done, infos
    def render(self, **kwargs):
        """Only set render options here, such that they can be used during the rollout.
        This only needs to be called once"""
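The rollout above asks self.replanning_schedule(current_pos, current_vel, obs, action, step) whether to cut the current segment and condition a new trajectory on the reached state. A minimal sketch of such a predicate; the fixed 25-step interval is illustrative, not the schedule used for table tennis:

def every_n_steps(n=25):
    """Replan after every n executed steps (illustrative schedule)."""
    def schedule(current_pos, current_vel, obs, action, step):
        # `step` is the absolute step count t + 1 + current_traj_steps from the loop above.
        return step % n == 0
    return schedule

replanning_schedule = every_n_steps(25)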

View File

@@ -72,6 +72,7 @@ DEFAULT_BB_DICT_ProDMP = {
    "wrappers": [],
    "trajectory_generator_kwargs": {
        'trajectory_generator_type': 'prodmp',
+       'duration': 2.0,
        'weights_scale': 1.0,
    },
    "phase_generator_kwargs": {
@@ -254,7 +255,7 @@ for ctxt_dim in [2, 4]:
    register(
        id='TableTennis{}D-v0'.format(ctxt_dim),
        entry_point='fancy_gym.envs.mujoco:TableTennisEnv',
-       max_episode_steps=500,
+       max_episode_steps=350,
        kwargs={
            "ctxt_dim": ctxt_dim,
            'frame_skip': 4
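The registered episode limit can be read back from the gym registry; a small sketch, assuming fancy_gym is installed so its registrations run on import:

import gym
import fancy_gym  # importing registers TableTennis2D-v0 / TableTennis4D-v0

spec = gym.spec('TableTennis4D-v0')
print(spec.max_episode_steps)  # 350, as registered above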

View File

@@ -51,8 +51,8 @@ class MPWrapper(RawInterfaceWrapper):
        time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
                       or action[1] > delay_bound[1] or action[1] < delay_bound[0]
        if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
-           return True
-       return False
+           return False
+       return True

    def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \
            -> Tuple[np.ndarray, float, bool, dict]:
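With the return values swapped, episode_callback follows the convention the BlackBoxWrapper expects: True for a usable trajectory, False when tau/delay leave their bounds or the desired joint positions violate the limits. A standalone sketch of that predicate with placeholder bounds:

import numpy as np

# Placeholder bounds; the real tau/delay and joint limits live in the wrapper and env.
tau_bound, delay_bound = [0.8, 1.5], [0.05, 0.15]
jnt_pos_low, jnt_pos_high = -2.0 * np.ones(7), 2.0 * np.ones(7)

def traj_is_valid(action, pos_traj):
    time_invalid = (action[0] > tau_bound[1] or action[0] < tau_bound[0]
                    or action[1] > delay_bound[1] or action[1] < delay_bound[0])
    if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
        return False
    return True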

View File

@@ -127,8 +127,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
    def reset_model(self):
        self._steps = 0
-       self._init_ball_state = self._generate_valid_init_ball(random_pos=False, random_vel=False)
-       self._goal_pos = self._generate_goal_pos(random=False)
+       self._init_ball_state = self._generate_valid_init_ball(random_pos=True, random_vel=False)
+       self._goal_pos = self._generate_goal_pos(random=True)
        self.data.joint("tar_x").qpos = self._init_ball_state[0]
        self.data.joint("tar_y").qpos = self._init_ball_state[1]
        self.data.joint("tar_z").qpos = self._init_ball_state[2]

View File

@@ -161,7 +161,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
if __name__ == '__main__':
-   render = False
+   render = True
    # DMP
    # example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
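For a quick end-to-end check of the step-based env registered above, a hedged usage sketch (old gym 4-tuple step API, as used throughout this code; fancy_gym.make with a seed argument is assumed available):

import fancy_gym

env = fancy_gym.make('TableTennis4D-v0', seed=1)
obs = env.reset()
done = False
while not done:
    action = env.action_space.sample()         # random policy, just to exercise the env
    obs, reward, done, info = env.step(action)
env.close()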