add plot for wind influence testing

This commit is contained in:
Hongyi Zhou 2022-11-15 23:38:50 +01:00
parent 7ba490f14a
commit f9c0c1f3ab

View File

@ -230,14 +230,48 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel) init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
return init_ball_state return init_ball_state
def plot_ball_traj(x_traj, y_traj, z_traj):
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(x_traj, y_traj, z_traj)
plt.show()
def plot_ball_traj_2d(x_traj, y_traj):
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(x_traj, y_traj)
plt.show()
def plot_single_axis(traj):
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(traj)
plt.show()
if __name__ == "__main__": if __name__ == "__main__":
env = TableTennisEnv(enable_wind=True) env = TableTennisEnv(enable_wind=True)
for _ in range(1000): for _ in range(5):
obs = env.reset() obs = env.reset()
x_pos = []
y_pos = []
z_pos = []
x_vel = []
y_vel = []
z_vel = []
for _ in range(2000): for _ in range(2000):
env.render("human") # env.render("human")
obs, reward, done, info = env.step(np.zeros(7)) obs, reward, done, info = env.step(np.zeros(7))
x_pos.append(env.data.joint("tar_x").qpos[0])
y_pos.append(env.data.joint("tar_y").qpos[0])
z_pos.append(env.data.joint("tar_z").qpos[0])
x_vel.append(env.data.joint("tar_x").qvel[0])
y_vel.append(env.data.joint("tar_y").qvel[0])
z_vel.append(env.data.joint("tar_z").qvel[0])
# print(reward) # print(reward)
if done: if done:
plot_ball_traj_2d(x_pos, y_pos)
plot_single_axis(x_vel)
break break