table tennis 4D replanning works git add .git add .

This commit is contained in:
Hongyi Zhou 2022-11-09 10:42:36 +01:00
parent b9c2348855
commit 99a514026f
5 changed files with 89 additions and 92 deletions

View File

@ -70,9 +70,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
# tricky_action_upperbound = [np.inf] * (self.traj_gen_action_space.shape[0] - 7)
# tricky_action_lowerbound = [-np.inf] * (self.traj_gen_action_space.shape[0] - 7)
# self.action_space = spaces.Box(np.array(tricky_action_lowerbound), np.array(tricky_action_upperbound), dtype=np.float32)
self.action_space.low[0] = 0.5
self.action_space.low[0] = 0.8
self.action_space.high[0] = 1.5
self.action_space.low[1] = 0.02
self.action_space.low[1] = 0.05
self.action_space.high[1] = 0.15
self.observation_space = self._get_observation_space()
@ -99,8 +99,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
def get_trajectory(self, action: np.ndarray) -> Tuple:
# duration = self.duration
# duration = self.duration - self.current_traj_steps * self.dt
duration = 2.
duration = self.duration - self.current_traj_steps * self.dt
if self.learn_sub_trajectories:
duration = None
# reset with every new call as we need to set all arguments, such as tau, delay, again.
@ -122,7 +121,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
self.condition_vel = torch.as_tensor(self.condition_vel, dtype=torch.float32)
self.traj_gen.set_boundary_conditions(bc_time, self.condition_pos, self.condition_vel)
# self.traj_gen.set_duration(duration, self.dt)
self.traj_gen.set_duration(self.tau_first_prediction, self.dt)
self.traj_gen.set_duration(duration, self.dt)
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
position = get_numpy(self.traj_gen.get_traj_pos())
velocity = get_numpy(self.traj_gen.get_traj_vel())
@ -164,12 +163,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
def step(self, action: np.ndarray):
""" This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
time_valid = self.env.check_time_validity(action)
if time_valid:
if self.plan_counts == 0:
self.tau_first_prediction = action[0]
# time_valid = self.env.check_time_validity(action)
#
# if time_valid:
## tricky part, only use weights basis
# basis_weights = action.reshape(7, -1)
@ -256,9 +252,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
else:
obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity)
return self.observation(obs), trajectory_return, done, infos
else:
obs, trajectory_return, done, infos = self.env.time_invalid_traj_callback(action)
return self.observation(obs), trajectory_return, done, infos
# else:
# obs, trajectory_return, done, infos = self.env.time_invalid_traj_callback(action)
# return self.observation(obs), trajectory_return, done, infos
def render(self, **kwargs):
"""Only set render options here, such that they can be used during the rollout.
This only needs to be called once"""

View File

@ -72,6 +72,7 @@ DEFAULT_BB_DICT_ProDMP = {
"wrappers": [],
"trajectory_generator_kwargs": {
'trajectory_generator_type': 'prodmp',
'duration': 2.0,
'weights_scale': 1.0,
},
"phase_generator_kwargs": {
@ -254,7 +255,7 @@ for ctxt_dim in [2, 4]:
register(
id='TableTennis{}D-v0'.format(ctxt_dim),
entry_point='fancy_gym.envs.mujoco:TableTennisEnv',
max_episode_steps=500,
max_episode_steps=350,
kwargs={
"ctxt_dim": ctxt_dim,
'frame_skip': 4

View File

@ -51,8 +51,8 @@ class MPWrapper(RawInterfaceWrapper):
time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
or action[1] > delay_bound[1] or action[1] < delay_bound[0]
if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
return True
return False
return True
def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \
-> Tuple[np.ndarray, float, bool, dict]:

View File

@ -127,8 +127,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
def reset_model(self):
self._steps = 0
self._init_ball_state = self._generate_valid_init_ball(random_pos=False, random_vel=False)
self._goal_pos = self._generate_goal_pos(random=False)
self._init_ball_state = self._generate_valid_init_ball(random_pos=True, random_vel=False)
self._goal_pos = self._generate_goal_pos(random=True)
self.data.joint("tar_x").qpos = self._init_ball_state[0]
self.data.joint("tar_y").qpos = self._init_ball_state[1]
self.data.joint("tar_z").qpos = self._init_ball_state[2]

View File

@ -161,7 +161,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
if __name__ == '__main__':
render = False
render = True
# DMP
# example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)