diff --git a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py index 94a48ab..1e508d4 100644 --- a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py +++ b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py @@ -230,14 +230,48 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel) return init_ball_state +def plot_ball_traj(x_traj, y_traj, z_traj): + import matplotlib.pyplot as plt + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + ax.plot(x_traj, y_traj, z_traj) + plt.show() + +def plot_ball_traj_2d(x_traj, y_traj): + import matplotlib.pyplot as plt + fig = plt.figure() + ax = fig.add_subplot(111) + ax.plot(x_traj, y_traj) + plt.show() + +def plot_single_axis(traj): + import matplotlib.pyplot as plt + fig = plt.figure() + ax = fig.add_subplot(111) + ax.plot(traj) + plt.show() if __name__ == "__main__": env = TableTennisEnv(enable_wind=True) - for _ in range(1000): + for _ in range(5): obs = env.reset() + x_pos = [] + y_pos = [] + z_pos = [] + x_vel = [] + y_vel = [] + z_vel = [] for _ in range(2000): - env.render("human") + # env.render("human") obs, reward, done, info = env.step(np.zeros(7)) + x_pos.append(env.data.joint("tar_x").qpos[0]) + y_pos.append(env.data.joint("tar_y").qpos[0]) + z_pos.append(env.data.joint("tar_z").qpos[0]) + x_vel.append(env.data.joint("tar_x").qvel[0]) + y_vel.append(env.data.joint("tar_y").qvel[0]) + z_vel.append(env.data.joint("tar_z").qvel[0]) # print(reward) if done: + plot_ball_traj_2d(x_pos, y_pos) + plot_single_axis(x_vel) break