fix bug in box pushing IK
This commit is contained in:
parent
e3d36dead0
commit
4a850912be
@ -71,6 +71,9 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
||||
qpos = self.data.qpos[:7].copy()
|
||||
qvel = self.data.qvel[:7].copy()
|
||||
|
||||
if (self._steps + 1) % 10 == 0:
|
||||
self.calculateOfflineIK(np.array([0.4, 0.3, 0.14]), np.array([0, 1, 0, 0]))
|
||||
|
||||
if not unstable_simulation:
|
||||
reward = self._get_reward(episode_end, box_pos, box_quat, target_pos, target_quat,
|
||||
rod_tip_pos, rod_quat, qpos, qvel, action)
|
||||
@ -93,7 +96,9 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
||||
def reset_model(self):
|
||||
# rest box to initial position
|
||||
self.set_state(self.init_qpos_box_pushing, self.init_qvel_box_pushing)
|
||||
box_init_pos = np.array([0.4, 0.3, -0.01, 0.0, 0.0, 0.0, 1.0])
|
||||
random_init_pos = self.sample_context()
|
||||
# box_init_pos = np.array([0.4, 0.3, -0.01, 0.0, 0.0, 0.0, 1.0])
|
||||
box_init_pos = random_init_pos
|
||||
self.data.joint("box_joint").qpos = box_init_pos
|
||||
|
||||
# set target position
|
||||
@ -219,6 +224,10 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
||||
q_old = q
|
||||
q = q + dt * qd_d
|
||||
q = np.clip(q, q_min, q_max)
|
||||
|
||||
self.data.qpos[:7] = q
|
||||
mujoco.mj_forward(self.model, self.data)
|
||||
|
||||
current_cart_pos = self.data.body("tcp").xpos.copy()
|
||||
current_cart_quat = self.data.body("tcp").xquat.copy()
|
||||
|
||||
@ -230,20 +239,30 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
||||
|
||||
err = np.hstack((cart_pos_error, cart_quat_error))
|
||||
err_norm = np.sum(cart_pos_error**2) + np.sum((current_cart_quat - desired_cart_quat)**2)
|
||||
|
||||
if err_norm > old_err_norm:
|
||||
# old_err_norm = err_norm
|
||||
q = q_old
|
||||
dt = 0.7 * dt
|
||||
|
||||
continue
|
||||
|
||||
else:
|
||||
dt = 1.025 * dt
|
||||
|
||||
i += 1
|
||||
|
||||
if err_norm < eps:
|
||||
print("IK converged in {} iterations".format(i))
|
||||
break
|
||||
|
||||
if i > IT_MAX:
|
||||
print("IK did not converge in {} iterations".format(i))
|
||||
break
|
||||
|
||||
old_err_norm = err_norm
|
||||
|
||||
|
||||
### get Jacobian by mujoco
|
||||
self.data.qpos[:7] = q
|
||||
mujoco.mj_forward(self.model, self.data)
|
||||
@ -272,7 +291,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
||||
|
||||
qd_d = w.dot(J.transpose()).dot(qd_d) + qd_null
|
||||
|
||||
i += 1
|
||||
# i += 1
|
||||
|
||||
return q
|
||||
|
||||
@ -360,10 +379,10 @@ class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase):
|
||||
if __name__=="__main__":
|
||||
env = BoxPushingTemporalSpatialSparse(frame_skip=10)
|
||||
env.reset()
|
||||
for i in range(10):
|
||||
for i in range(100):
|
||||
env.reset()
|
||||
for _ in range(100):
|
||||
env.render("human")
|
||||
action = env.action_space.sample()
|
||||
obs, reward, done, info = env.step(action)
|
||||
print("info: {}".format(info))
|
||||
# print("info: {}".format(info))
|
||||
|
@ -29,24 +29,24 @@ class MPWrapper(RawInterfaceWrapper):
|
||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
return self.data.qvel[:7].copy()
|
||||
|
||||
def check_time_validity(self, action):
|
||||
return action[0] <= tau_bound[1] and action[0] >= tau_bound[0] \
|
||||
and action[1] <= delay_bound[1] and action[1] >= delay_bound[0]
|
||||
|
||||
def time_invalid_traj_callback(self, action, pos_traj, vel_traj) \
|
||||
-> Tuple[np.ndarray, float, bool, dict]:
|
||||
tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
|
||||
delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
|
||||
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty
|
||||
obs = np.concatenate([self.get_obs(), np.array([0])])
|
||||
return obs, -invalid_penalty, True, {
|
||||
"hit_ball": [False],
|
||||
"ball_returned_success": [False],
|
||||
"land_dist_error": [10.],
|
||||
"is_success": [False],
|
||||
'trajectory_length': 1,
|
||||
"num_steps": [1]
|
||||
}
|
||||
# def check_time_validity(self, action):
|
||||
# return action[0] <= tau_bound[1] and action[0] >= tau_bound[0] \
|
||||
# and action[1] <= delay_bound[1] and action[1] >= delay_bound[0]
|
||||
#
|
||||
# def time_invalid_traj_callback(self, action, pos_traj, vel_traj) \
|
||||
# -> Tuple[np.ndarray, float, bool, dict]:
|
||||
# tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
|
||||
# delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
|
||||
# invalid_penalty = tau_invalid_penalty + delay_invalid_penalty
|
||||
# obs = np.concatenate([self.get_obs(), np.array([0])])
|
||||
# return obs, -invalid_penalty, True, {
|
||||
# "hit_ball": [False],
|
||||
# "ball_returned_success": [False],
|
||||
# "land_dist_error": [10.],
|
||||
# "is_success": [False],
|
||||
# 'trajectory_length': 1,
|
||||
# "num_steps": [1]
|
||||
# }
|
||||
|
||||
def episode_callback(self, action, pos_traj, vel_traj):
|
||||
time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
|
||||
|
22
fancy_gym/examples/table_tennis_reproducebility.py
Normal file
22
fancy_gym/examples/table_tennis_reproducebility.py
Normal file
@ -0,0 +1,22 @@
|
||||
import fancy_gym
|
||||
import numpy as np
|
||||
|
||||
env_1 = fancy_gym.make("TableTennis4DProDMP-v0", seed=0)
|
||||
env_2 = fancy_gym.make("TableTennis4DProDMP-v0", seed=0)
|
||||
|
||||
obs_1 = env_1.reset()
|
||||
obs_2 = env_2.reset()
|
||||
assert np.all(obs_1 == obs_2), "The observations should be the same"
|
||||
for i in range(100000):
|
||||
action = env_1.action_space.sample()
|
||||
obs_1, reward_1, done_1, info_1 = env_1.step(action)
|
||||
obs_2, reward_2, done_2, info_2 = env_2.step(action)
|
||||
assert np.all(obs_1 == obs_2), "The observations should be the same"
|
||||
assert np.all(reward_1 == reward_2), "The rewards should be the same"
|
||||
assert np.all(done_1 == done_2), "The done flags should be the same"
|
||||
for key in info_1:
|
||||
assert np.all(info_1[key] == info_2[key]), f"The info fields: {key} should be the same"
|
||||
if done_1 and done_2:
|
||||
obs_1 = env_1.reset()
|
||||
obs_2 = env_2.reset()
|
||||
assert np.all(obs_1 == obs_2), "The observations should be the same"
|
Loading…
Reference in New Issue
Block a user