using artifical wind field
This commit is contained in:
parent
96f17e02cf
commit
b883ad63b7
@ -260,8 +260,9 @@ for ctxt_dim in [2, 4]:
|
|||||||
"ctxt_dim": ctxt_dim,
|
"ctxt_dim": ctxt_dim,
|
||||||
'frame_skip': 4,
|
'frame_skip': 4,
|
||||||
'enable_wind': False,
|
'enable_wind': False,
|
||||||
'enable_switching_goal': True,
|
'enable_switching_goal': False,
|
||||||
'enable_air': False,
|
'enable_air': False,
|
||||||
|
'enable_artifical_wind': True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
[False] * 7, # joints velocity
|
[False] * 7, # joints velocity
|
||||||
[True] * 2, # position ball x, y
|
[True] * 2, # position ball x, y
|
||||||
[False] * 1, # position ball z
|
[False] * 1, # position ball z
|
||||||
# [False] * 3, # velocity ball x, y, z
|
[True] * 3, # velocity ball x, y, z
|
||||||
[True] * 2, # target landing position
|
[True] * 2, # target landing position
|
||||||
# [True] * 1, # time
|
# [True] * 1, # time
|
||||||
])
|
])
|
||||||
|
@ -24,7 +24,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
|
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
|
||||||
enable_switching_goal: bool = False,
|
enable_switching_goal: bool = False,
|
||||||
enable_wind: bool = False, enable_magnus: bool = False,
|
enable_wind: bool = False,
|
||||||
|
enable_artifical_wind: bool = False,
|
||||||
|
enable_magnus: bool = False,
|
||||||
enable_air: bool = False):
|
enable_air: bool = False):
|
||||||
utils.EzPickle.__init__(**locals())
|
utils.EzPickle.__init__(**locals())
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
@ -48,6 +50,10 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
self._enable_goal_switching = enable_switching_goal
|
self._enable_goal_switching = enable_switching_goal
|
||||||
|
|
||||||
|
self._enable_artifical_wind = enable_artifical_wind
|
||||||
|
|
||||||
|
self._artifical_force = 0.
|
||||||
|
|
||||||
MujocoEnv.__init__(self,
|
MujocoEnv.__init__(self,
|
||||||
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
||||||
frame_skip=frame_skip,
|
frame_skip=frame_skip,
|
||||||
@ -86,6 +92,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
unstable_simulation = False
|
unstable_simulation = False
|
||||||
|
|
||||||
|
|
||||||
if self._enable_goal_switching:
|
if self._enable_goal_switching:
|
||||||
if self._steps == 99 and self.np_random.uniform(0, 1) < 0.5:
|
if self._steps == 99 and self.np_random.uniform(0, 1) < 0.5:
|
||||||
new_goal_pos = self._generate_goal_pos(random=True)
|
new_goal_pos = self._generate_goal_pos(random=True)
|
||||||
@ -95,6 +102,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
mujoco.mj_forward(self.model, self.data)
|
mujoco.mj_forward(self.model, self.data)
|
||||||
|
|
||||||
for _ in range(self.frame_skip):
|
for _ in range(self.frame_skip):
|
||||||
|
if self._enable_artifical_wind:
|
||||||
|
self.data.qfrc_applied[-2] = self._artifical_force
|
||||||
try:
|
try:
|
||||||
self.do_simulation(action, 1)
|
self.do_simulation(action, 1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -154,7 +163,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
self._init_ball_state = self._generate_valid_init_ball(random_pos=True, random_vel=False)
|
self._init_ball_state = self._generate_valid_init_ball(random_pos=True, random_vel=False)
|
||||||
self._init_ball_state[2] = 1.85
|
# self._init_ball_state[2] = 1.85
|
||||||
self._goal_pos = self._generate_goal_pos(random=True)
|
self._goal_pos = self._generate_goal_pos(random=True)
|
||||||
self.data.joint("tar_x").qpos = self._init_ball_state[0]
|
self.data.joint("tar_x").qpos = self._init_ball_state[0]
|
||||||
self.data.joint("tar_y").qpos = self._init_ball_state[1]
|
self.data.joint("tar_y").qpos = self._init_ball_state[1]
|
||||||
@ -163,6 +172,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
self.data.joint("tar_y").qvel = self._init_ball_state[4]
|
self.data.joint("tar_y").qvel = self._init_ball_state[4]
|
||||||
self.data.joint("tar_z").qvel = self._init_ball_state[5]
|
self.data.joint("tar_z").qvel = self._init_ball_state[5]
|
||||||
|
|
||||||
|
if self._enable_artifical_wind:
|
||||||
|
self._artifical_force = self.np_random.uniform(low=-0.1, high=0.1)
|
||||||
|
|
||||||
self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]])
|
self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]])
|
||||||
|
|
||||||
self.data.qpos[:7] = np.array([0., 0., 0., 1.5, 0., 0., 1.5])
|
self.data.qpos[:7] = np.array([0., 0., 0., 1.5, 0., 0., 1.5])
|
||||||
@ -196,9 +208,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
self.data.joint("tar_x").qpos.copy(),
|
self.data.joint("tar_x").qpos.copy(),
|
||||||
self.data.joint("tar_y").qpos.copy(),
|
self.data.joint("tar_y").qpos.copy(),
|
||||||
self.data.joint("tar_z").qpos.copy(),
|
self.data.joint("tar_z").qpos.copy(),
|
||||||
# self.data.joint("tar_x").qvel.copy(),
|
self.data.joint("tar_x").qvel.copy(),
|
||||||
# self.data.joint("tar_y").qvel.copy(),
|
self.data.joint("tar_y").qvel.copy(),
|
||||||
# self.data.joint("tar_z").qvel.copy(),
|
self.data.joint("tar_z").qvel.copy(),
|
||||||
# self.data.body("target_ball").xvel.copy(),
|
# self.data.body("target_ball").xvel.copy(),
|
||||||
self._goal_pos.copy(),
|
self._goal_pos.copy(),
|
||||||
])
|
])
|
||||||
@ -265,7 +277,7 @@ def plot_compare_trajs(traj1, traj2, title):
|
|||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
env_air = TableTennisEnv(enable_air=True, enable_wind=False)
|
env_air = TableTennisEnv(enable_air=False, enable_wind=False, enable_artifical_wind=True)
|
||||||
env_no_air = TableTennisEnv(enable_air=False, enable_wind=False)
|
env_no_air = TableTennisEnv(enable_air=False, enable_wind=False)
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
obs1 = env_air.reset()
|
obs1 = env_air.reset()
|
||||||
@ -279,18 +291,17 @@ if __name__ == "__main__":
|
|||||||
# y_vel = []
|
# y_vel = []
|
||||||
# z_vel = []
|
# z_vel = []
|
||||||
for _ in range(2000):
|
for _ in range(2000):
|
||||||
# env_air.render("human")
|
env_air.render("human")
|
||||||
obs1, reward1, done1, info1 = env_air.step(np.zeros(7))
|
obs1, reward1, done1, info1 = env_air.step(np.zeros(7))
|
||||||
obs2, reward2, done2, info2 = env_no_air.step(np.zeros(7))
|
obs2, reward2, done2, info2 = env_no_air.step(np.zeros(7))
|
||||||
# _, _, _, _ = env_no_air.step(np.zeros(7))
|
# # _, _, _, _ = env_no_air.step(np.zeros(7))
|
||||||
air_x_pos.append(env_air.data.joint("tar_z").qpos[0])
|
air_x_pos.append(env_air.data.joint("tar_z").qpos[0])
|
||||||
no_air_x_pos.append(env_no_air.data.joint("tar_z").qpos[0])
|
no_air_x_pos.append(env_no_air.data.joint("tar_z").qpos[0])
|
||||||
# z_pos.append(env.data.joint("tar_z").qpos[0])
|
# # z_pos.append(env.data.joint("tar_z").qpos[0])
|
||||||
# x_vel.append(env.data.joint("tar_x").qvel[0])
|
# # x_vel.append(env.data.joint("tar_x").qvel[0])
|
||||||
# y_vel.append(env.data.joint("tar_y").qvel[0])
|
# # y_vel.append(env.data.joint("tar_y").qvel[0])
|
||||||
# z_vel.append(env.data.joint("tar_z").qvel[0])
|
# # z_vel.append(env.data.joint("tar_z").qvel[0])
|
||||||
# print(reward)
|
# # print(reward)
|
||||||
if info1["num_steps"] == 150:
|
if info1["num_steps"] == 150:
|
||||||
# plot_ball_traj_2d(x_pos, y_pos)
|
plot_compare_trajs(air_x_pos, no_air_x_pos, title="z_pos with/out wind")
|
||||||
plot_compare_trajs(air_x_pos, no_air_x_pos, title="z_pos with/out air")
|
|
||||||
break
|
break
|
||||||
|
Loading…
Reference in New Issue
Block a user