fix bug in box pushing IK
parent e3d36dead0
commit 4a850912be
@@ -71,6 +71,9 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
         qpos = self.data.qpos[:7].copy()
         qvel = self.data.qvel[:7].copy()
 
+        if (self._steps + 1) % 10 == 0:
+            self.calculateOfflineIK(np.array([0.4, 0.3, 0.14]), np.array([0, 1, 0, 0]))
+
         if not unstable_simulation:
             reward = self._get_reward(episode_end, box_pos, box_quat, target_pos, target_quat,
                                       rod_tip_pos, rod_quat, qpos, qvel, action)
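The added block calls calculateOfflineIK every tenth step with a fixed Cartesian target, which reads like a test hook for the IK fix further down. A minimal sketch of how such a solution could be sanity-checked, assuming only the MuJoCo Python bindings and this environment's "tcp" body (the helper name is ours, not the repository's):

import mujoco
import numpy as np

def ik_reaches_target(model, data, q, target_pos, tol=1e-2):
    # Write the IK solution into the configuration and recompute forward
    # kinematics so body poses reflect the new joint angles.
    data.qpos[:7] = q
    mujoco.mj_forward(model, data)
    # Accept the solution if the TCP landed within tol of the target.
    tcp_pos = data.body("tcp").xpos.copy()
    return np.linalg.norm(tcp_pos - target_pos) < tol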
@@ -93,7 +96,9 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
     def reset_model(self):
         # rest box to initial position
         self.set_state(self.init_qpos_box_pushing, self.init_qvel_box_pushing)
-        box_init_pos = np.array([0.4, 0.3, -0.01, 0.0, 0.0, 0.0, 1.0])
+        random_init_pos = self.sample_context()
+        # box_init_pos = np.array([0.4, 0.3, -0.01, 0.0, 0.0, 0.0, 1.0])
+        box_init_pos = random_init_pos
         self.data.joint("box_joint").qpos = box_init_pos
 
         # set target position
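reset_model now draws the box start pose from sample_context instead of the fixed pose kept in the comment. The sampler itself is outside this diff; a plausible sketch, assuming a 7-dim free-joint pose (3D position plus unit quaternion) and an illustrative workspace bound (the constant and its values below are assumptions, not the repository's):

import numpy as np

# Illustrative (x, y, z) bounds for the box start position.
BOX_POS_BOUND = np.array([[0.3, -0.45, -0.01], [0.6, 0.45, -0.01]])

def sample_context(np_random):
    # Random position on the table ...
    pos = np_random.uniform(low=BOX_POS_BOUND[0], high=BOX_POS_BOUND[1])
    # ... and a random yaw about z, encoded as a (w, x, y, z) quaternion.
    theta = np_random.uniform(low=0.0, high=2.0 * np.pi)
    quat = np.array([np.cos(theta / 2), 0.0, 0.0, np.sin(theta / 2)])
    return np.concatenate([pos, quat])  # 7-dim free-joint pose

Called as sample_context(self.np_random), this stays reproducible under the environment's seeding.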
@@ -219,6 +224,10 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
             q_old = q
             q = q + dt * qd_d
             q = np.clip(q, q_min, q_max)
+
+            self.data.qpos[:7] = q
+            mujoco.mj_forward(self.model, self.data)
+
             current_cart_pos = self.data.body("tcp").xpos.copy()
             current_cart_quat = self.data.body("tcp").xquat.copy()
 
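This appears to be the bug named in the commit title: the loop updated q but read the TCP pose before pushing the new configuration into the simulation, so the error terms were computed against a stale pose (the only forward pass ran during the previous iteration's Jacobian step, visible at the end of the next hunk). In MuJoCo, body poses are derived quantities; writing qpos alone changes nothing until a forward pass runs. A standalone toy model illustrating that (not the box-pushing XML):

import mujoco
import numpy as np

model = mujoco.MjModel.from_xml_string("""
<mujoco>
  <worldbody>
    <body name="tcp">
      <joint type="slide" axis="1 0 0"/>
      <geom size="0.01"/>
    </body>
  </worldbody>
</mujoco>
""")
data = mujoco.MjData(model)

data.qpos[0] = 0.5
stale = data.body("tcp").xpos.copy()   # still [0, 0, 0]: kinematics not recomputed
mujoco.mj_forward(model, data)
fresh = data.body("tcp").xpos.copy()   # now [0.5, 0, 0]
print(stale, fresh)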
@@ -230,20 +239,30 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
 
             err = np.hstack((cart_pos_error, cart_quat_error))
             err_norm = np.sum(cart_pos_error**2) + np.sum((current_cart_quat - desired_cart_quat)**2)
 
             if err_norm > old_err_norm:
+                # old_err_norm = err_norm
                 q = q_old
                 dt = 0.7 * dt
+
                 continue
+
             else:
                 dt = 1.025 * dt
+
+            i += 1
+
             if err_norm < eps:
+                print("IK converged in {} iterations".format(i))
                 break
+
             if i > IT_MAX:
+                print("IK did not converge in {} iterations".format(i))
                 break
 
             old_err_norm = err_norm
 
+
             ### get Jacobian by mujoco
             self.data.qpos[:7] = q
             mujoco.mj_forward(self.model, self.data)
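The surrounding loop is an iterative IK solver with a multiplicative step-size rule: when the error norm increases, the candidate q is rejected and dt shrinks by 0.7 before retrying; when it decreases, dt grows by 1.025. Moving i += 1 ahead of the termination checks (the old increment in the Jacobian section is commented out in the next hunk) makes the new convergence messages report the count of accepted iterations. The same accept/reject rule on a toy least-squares problem, as a self-contained sketch:

import numpy as np

def adaptive_descent(f, grad, x, dt=1e-3, eps=1e-10, it_max=1000):
    # Same rule as the IK loop: shrink the step by 0.7 on a rejected
    # iterate, grow it by 1.025 on an accepted one.
    old_err = np.inf
    i = 0
    while True:
        x_old = x
        x = x - dt * grad(x)
        err = f(x)
        if err > old_err:
            x = x_old        # reject: revert and retry with a smaller step
            dt = 0.7 * dt
            continue
        dt = 1.025 * dt      # accept: cautiously enlarge the step
        i += 1
        if err < eps or i > it_max:
            return x, i
        old_err = err

# Toy problem: minimize ||A x - b||^2.
A = np.array([[2.0, 0.0], [0.0, 1.0]])
b = np.array([1.0, -1.0])
x, iters = adaptive_descent(lambda x: np.sum((A @ x - b) ** 2),
                            lambda x: 2 * A.T @ (A @ x - b),
                            np.zeros(2))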
@@ -272,7 +291,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
 
             qd_d = w.dot(J.transpose()).dot(qd_d) + qd_null
 
-            i += 1
+            # i += 1
 
         return q
 
@@ -360,10 +379,10 @@ class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase):
 if __name__=="__main__":
     env = BoxPushingTemporalSpatialSparse(frame_skip=10)
     env.reset()
-    for i in range(10):
+    for i in range(100):
         env.reset()
         for _ in range(100):
             env.render("human")
             action = env.action_space.sample()
             obs, reward, done, info = env.step(action)
-            print("info: {}".format(info))
+            # print("info: {}".format(info))
@@ -29,24 +29,24 @@ class MPWrapper(RawInterfaceWrapper):
     def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
         return self.data.qvel[:7].copy()
 
-    def check_time_validity(self, action):
-        return action[0] <= tau_bound[1] and action[0] >= tau_bound[0] \
-               and action[1] <= delay_bound[1] and action[1] >= delay_bound[0]
-
-    def time_invalid_traj_callback(self, action, pos_traj, vel_traj) \
-            -> Tuple[np.ndarray, float, bool, dict]:
-        tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
-        delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
-        invalid_penalty = tau_invalid_penalty + delay_invalid_penalty
-        obs = np.concatenate([self.get_obs(), np.array([0])])
-        return obs, -invalid_penalty, True, {
-            "hit_ball": [False],
-            "ball_returned_success": [False],
-            "land_dist_error": [10.],
-            "is_success": [False],
-            'trajectory_length': 1,
-            "num_steps": [1]
-        }
+    # def check_time_validity(self, action):
+    #     return action[0] <= tau_bound[1] and action[0] >= tau_bound[0] \
+    #            and action[1] <= delay_bound[1] and action[1] >= delay_bound[0]
+    #
+    # def time_invalid_traj_callback(self, action, pos_traj, vel_traj) \
+    #         -> Tuple[np.ndarray, float, bool, dict]:
+    #     tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
+    #     delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
+    #     invalid_penalty = tau_invalid_penalty + delay_invalid_penalty
+    #     obs = np.concatenate([self.get_obs(), np.array([0])])
+    #     return obs, -invalid_penalty, True, {
+    #         "hit_ball": [False],
+    #         "ball_returned_success": [False],
+    #         "land_dist_error": [10.],
+    #         "is_success": [False],
+    #         'trajectory_length': 1,
+    #         "num_steps": [1]
+    #     }
 
     def episode_callback(self, action, pos_traj, vel_traj):
         time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
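The now-commented check_time_validity and time_invalid_traj_callback gated trajectories on tau (duration) and delay bounds and charged a hinge-style penalty, 3 * (max(0, x - hi) + max(0, lo - x)), which is zero inside the bounds and grows linearly with the violation. A small worked example with hypothetical bounds (the real tau_bound and delay_bound values live in the table tennis module):

import numpy as np

def bound_penalty(x, lo, hi, scale=3.0):
    # Zero inside [lo, hi]; linear in the violation outside.
    return scale * (np.maximum(0, x - hi) + np.maximum(0, lo - x))

tau_bound, delay_bound = [0.0, 4.0], [0.0, 3.0]   # hypothetical values
print(bound_penalty(4.5, *tau_bound))    # 1.5: exceeds the upper bound by 0.5
print(bound_penalty(1.0, *delay_bound))  # 0.0: inside the bounds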
fancy_gym/examples/table_tennis_reproducebility.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+import fancy_gym
+import numpy as np
+
+env_1 = fancy_gym.make("TableTennis4DProDMP-v0", seed=0)
+env_2 = fancy_gym.make("TableTennis4DProDMP-v0", seed=0)
+
+obs_1 = env_1.reset()
+obs_2 = env_2.reset()
+assert np.all(obs_1 == obs_2), "The observations should be the same"
+for i in range(100000):
+    action = env_1.action_space.sample()
+    obs_1, reward_1, done_1, info_1 = env_1.step(action)
+    obs_2, reward_2, done_2, info_2 = env_2.step(action)
+    assert np.all(obs_1 == obs_2), "The observations should be the same"
+    assert np.all(reward_1 == reward_2), "The rewards should be the same"
+    assert np.all(done_1 == done_2), "The done flags should be the same"
+    for key in info_1:
+        assert np.all(info_1[key] == info_2[key]), f"The info fields: {key} should be the same"
+    if done_1 and done_2:
+        obs_1 = env_1.reset()
+        obs_2 = env_2.reset()
+        assert np.all(obs_1 == obs_2), "The observations should be the same"
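The new example script is a determinism check: two environments built with the same id and seed are driven with identical sampled actions, and any divergence in observations, rewards, done flags, or info entries trips an assertion. It can be run directly with python fancy_gym/examples/table_tennis_reproducebility.py; a silent run over the full 100000 steps means the two rollouts stayed identical. Note that actions are sampled from env_1's action space only, so the check verifies that the environment, not the action source, is the sole remaining source of randomness after seeding.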