From 96f17e02cfe4e5dee6dbc36bb857ed5b63531107 Mon Sep 17 00:00:00 2001 From: Hongyi Zhou Date: Thu, 24 Nov 2022 14:15:09 +0100 Subject: [PATCH] random sampling for goal switching & adjust height for initial ball state --- fancy_gym/envs/__init__.py | 3 +- .../mujoco/table_tennis/table_tennis_env.py | 61 +++++++++++-------- 2 files changed, 37 insertions(+), 27 deletions(-) diff --git a/fancy_gym/envs/__init__.py b/fancy_gym/envs/__init__.py index b3ba3aa..f74bdcd 100644 --- a/fancy_gym/envs/__init__.py +++ b/fancy_gym/envs/__init__.py @@ -260,7 +260,8 @@ for ctxt_dim in [2, 4]: "ctxt_dim": ctxt_dim, 'frame_skip': 4, 'enable_wind': False, - 'enable_switching_goal': False, + 'enable_switching_goal': True, + 'enable_air': False, } ) diff --git a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py index 72b92e0..7a77443 100644 --- a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py +++ b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py @@ -87,8 +87,10 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): unstable_simulation = False if self._enable_goal_switching: - if self._steps == 45 and self.np_random.uniform(0, 1) < 0.5: - self._goal_pos[1] = -self._goal_pos[1] + if self._steps == 99 and self.np_random.uniform(0, 1) < 0.5: + new_goal_pos = self._generate_goal_pos(random=True) + new_goal_pos[1] = -new_goal_pos[1] + self._goal_pos = new_goal_pos self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]]) mujoco.mj_forward(self.model, self.data) @@ -151,8 +153,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): def reset_model(self): self._steps = 0 - self._init_ball_state = self._generate_valid_init_ball(random_pos=False, random_vel=False) - self._goal_pos = self._generate_goal_pos(random=False) + self._init_ball_state = self._generate_valid_init_ball(random_pos=True, random_vel=False) + self._init_ball_state[2] = 1.85 + self._goal_pos = self._generate_goal_pos(random=True) self.data.joint("tar_x").qpos = self._init_ball_state[0] self.data.joint("tar_y").qpos = self._init_ball_state[1] self.data.joint("tar_z").qpos = self._init_ball_state[2] @@ -167,7 +170,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): mujoco.mj_forward(self.model, self.data) if self._enable_wind: - self._wind_vel[1] = self.np_random.uniform(low=-5, high=5, size=1) + self._wind_vel[1] = self.np_random.uniform(low=-10, high=10, size=1) self.model.opt.wind[:3] = self._wind_vel self._hit_ball = False @@ -251,37 +254,43 @@ def plot_ball_traj_2d(x_traj, y_traj): ax.plot(x_traj, y_traj) plt.show() -def plot_single_axis(traj, title): +def plot_compare_trajs(traj1, traj2, title): import matplotlib.pyplot as plt fig = plt.figure() ax = fig.add_subplot(111) - ax.plot(traj) + ax.plot(traj1, color='r', label='traj1') + ax.plot(traj2, color='b', label='traj2') ax.set_title(title) + plt.legend() plt.show() if __name__ == "__main__": - env = TableTennisEnv(enable_air=True) - # env_with_air = TableTennisEnv(enable_air=True) - for _ in range(1): - obs1 = env.reset() + env_air = TableTennisEnv(enable_air=True, enable_wind=False) + env_no_air = TableTennisEnv(enable_air=False, enable_wind=False) + for _ in range(10): + obs1 = env_air.reset() + obs2 = env_no_air.reset() # obs2 = env_with_air.reset() - x_pos = [] - y_pos = [] - z_pos = [] - x_vel = [] - y_vel = [] - z_vel = [] + air_x_pos = [] + no_air_x_pos = [] + # y_pos = [] + # z_pos = [] + # x_vel = [] + # y_vel = [] + # z_vel = [] for _ in range(2000): - obs, reward, done, info = env.step(np.zeros(7)) + # env_air.render("human") + obs1, reward1, done1, info1 = env_air.step(np.zeros(7)) + obs2, reward2, done2, info2 = env_no_air.step(np.zeros(7)) # _, _, _, _ = env_no_air.step(np.zeros(7)) - x_pos.append(env.data.joint("tar_x").qpos[0]) - y_pos.append(env.data.joint("tar_y").qpos[0]) - z_pos.append(env.data.joint("tar_z").qpos[0]) - x_vel.append(env.data.joint("tar_x").qvel[0]) - y_vel.append(env.data.joint("tar_y").qvel[0]) - z_vel.append(env.data.joint("tar_z").qvel[0]) + air_x_pos.append(env_air.data.joint("tar_z").qpos[0]) + no_air_x_pos.append(env_no_air.data.joint("tar_z").qpos[0]) + # z_pos.append(env.data.joint("tar_z").qpos[0]) + # x_vel.append(env.data.joint("tar_x").qvel[0]) + # y_vel.append(env.data.joint("tar_y").qvel[0]) + # z_vel.append(env.data.joint("tar_z").qvel[0]) # print(reward) - if done: + if info1["num_steps"] == 150: # plot_ball_traj_2d(x_pos, y_pos) - plot_single_axis(x_pos, title="x_vel without air") + plot_compare_trajs(air_x_pos, no_air_x_pos, title="z_pos with/out air") break