table tennis 4D replanning works

Hongyi Zhou 2022-11-09 10:42:36 +01:00
parent b9c2348855
commit 99a514026f
5 changed files with 89 additions and 92 deletions

View File

@@ -70,9 +70,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         # tricky_action_upperbound = [np.inf] * (self.traj_gen_action_space.shape[0] - 7)
         # tricky_action_lowerbound = [-np.inf] * (self.traj_gen_action_space.shape[0] - 7)
         # self.action_space = spaces.Box(np.array(tricky_action_lowerbound), np.array(tricky_action_upperbound), dtype=np.float32)
-        self.action_space.low[0] = 0.5
+        self.action_space.low[0] = 0.8
         self.action_space.high[0] = 1.5
-        self.action_space.low[1] = 0.02
+        self.action_space.low[1] = 0.05
         self.action_space.high[1] = 0.15
         self.observation_space = self._get_observation_space()
@@ -99,8 +99,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
     def get_trajectory(self, action: np.ndarray) -> Tuple:
         # duration = self.duration
-        # duration = self.duration - self.current_traj_steps * self.dt
-        duration = 2.
+        duration = self.duration - self.current_traj_steps * self.dt
         if self.learn_sub_trajectories:
             duration = None
             # reset with every new call as we need to set all arguments, such as tau, delay, again.
@@ -122,7 +121,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
             self.condition_vel = torch.as_tensor(self.condition_vel, dtype=torch.float32)
             self.traj_gen.set_boundary_conditions(bc_time, self.condition_pos, self.condition_vel)
         # self.traj_gen.set_duration(duration, self.dt)
-        self.traj_gen.set_duration(self.tau_first_prediction, self.dt)
+        self.traj_gen.set_duration(duration, self.dt)
         # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
         position = get_numpy(self.traj_gen.get_traj_pos())
         velocity = get_numpy(self.traj_gen.get_traj_vel())
@@ -164,12 +163,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
     def step(self, action: np.ndarray):
         """ This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
-        time_valid = self.env.check_time_validity(action)
-
-        if time_valid:
-            if self.plan_counts == 0:
-                self.tau_first_prediction = action[0]
+        # time_valid = self.env.check_time_validity(action)
+        #
+        # if time_valid:

         ## tricky part, only use weights basis
         # basis_weights = action.reshape(7, -1)
@@ -256,9 +252,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
             else:
                 obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity)
                 return self.observation(obs), trajectory_return, done, infos
-        else:
-            obs, trajectory_return, done, infos = self.env.time_invalid_traj_callback(action)
-            return self.observation(obs), trajectory_return, done, infos
+        # else:
+        #     obs, trajectory_return, done, infos = self.env.time_invalid_traj_callback(action)
+        #     return self.observation(obs), trajectory_return, done, infos

     def render(self, **kwargs):
         """Only set render options here, such that they can be used during the rollout.
         This only needs to be called once"""
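
Note on the replanning change above: get_trajectory no longer generates a fixed 2 s trajectory (or one driven by the first predicted tau); each replanning call now only covers the time still left in the episode. A minimal sketch of that bookkeeping, with illustrative numbers (the episode length, dt, and step count below are assumptions, not values from this commit):

    # Sketch: remaining-duration computation used when replanning.
    total_duration = 2.0       # full episode length in seconds (assumed)
    dt = 0.002                 # control timestep in seconds (assumed)
    current_traj_steps = 250   # env steps already executed before this replan (assumed)

    # Each replanning call generates only the still-unexecuted part of the
    # trajectory, so the MP is conditioned to finish in the remaining time:
    duration = total_duration - current_traj_steps * dt
    assert duration == 1.5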

View File

@@ -72,6 +72,7 @@ DEFAULT_BB_DICT_ProDMP = {
     "wrappers": [],
     "trajectory_generator_kwargs": {
         'trajectory_generator_type': 'prodmp',
+        'duration': 2.0,
         'weights_scale': 1.0,
     },
     "phase_generator_kwargs": {
@@ -254,7 +255,7 @@ for ctxt_dim in [2, 4]:
     register(
         id='TableTennis{}D-v0'.format(ctxt_dim),
         entry_point='fancy_gym.envs.mujoco:TableTennisEnv',
-        max_episode_steps=500,
+        max_episode_steps=350,
         kwargs={
             "ctxt_dim": ctxt_dim,
             'frame_skip': 4
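
For reference, the loop above registers TableTennis2D-v0 and TableTennis4D-v0, now capped at 350 steps per episode. A usage sketch, assuming the fancy_gym.make API of this version:

    import fancy_gym

    # Roll out one episode with random actions; with the new limit the
    # episode terminates after at most 350 steps.
    env = fancy_gym.make('TableTennis4D-v0', seed=1)
    env.reset()
    done, steps = False, 0
    while not done:
        _, _, done, _ = env.step(env.action_space.sample())
        steps += 1
    assert steps <= 350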

View File

@@ -51,8 +51,8 @@ class MPWrapper(RawInterfaceWrapper):
         time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
             or action[1] > delay_bound[1] or action[1] < delay_bound[0]
         if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
-            return True
-        return False
+            return False
+        return True

     def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \
             -> Tuple[np.ndarray, float, bool, dict]:
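
The flipped returns above change the helper's contract from "is the trajectory invalid?" to "is the trajectory valid?", so callers now take the normal rollout path on True. A standalone sketch of the check; the bound values are assumptions chosen to mirror the new action-space bounds, and the joint limits are placeholders:

    import numpy as np

    tau_bound = [0.8, 1.5]       # assumed; mirrors the new bounds on action[0]
    delay_bound = [0.05, 0.15]   # assumed; mirrors the new bounds on action[1]
    jnt_pos_low = np.full(7, -2.0)   # placeholder joint limits
    jnt_pos_high = np.full(7, 2.0)

    def traj_is_valid(action, pos_traj):
        # True only if tau/delay lie inside their bounds and the whole
        # position trajectory stays within the joint limits.
        time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
            or action[1] > delay_bound[1] or action[1] < delay_bound[0]
        if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
            return False
        return True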

View File

@@ -127,8 +127,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
     def reset_model(self):
         self._steps = 0
-        self._init_ball_state = self._generate_valid_init_ball(random_pos=False, random_vel=False)
-        self._goal_pos = self._generate_goal_pos(random=False)
+        self._init_ball_state = self._generate_valid_init_ball(random_pos=True, random_vel=False)
+        self._goal_pos = self._generate_goal_pos(random=True)
         self.data.joint("tar_x").qpos = self._init_ball_state[0]
         self.data.joint("tar_y").qpos = self._init_ball_state[1]
         self.data.joint("tar_z").qpos = self._init_ball_state[2]

View File

@@ -161,7 +161,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
 if __name__ == '__main__':
-    render = False
+    render = True

     # DMP
     # example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)