diff --git a/fancy_gym/envs/__init__.py b/fancy_gym/envs/__init__.py index f74bdcd..58512bc 100644 --- a/fancy_gym/envs/__init__.py +++ b/fancy_gym/envs/__init__.py @@ -260,8 +260,9 @@ for ctxt_dim in [2, 4]: "ctxt_dim": ctxt_dim, 'frame_skip': 4, 'enable_wind': False, - 'enable_switching_goal': True, + 'enable_switching_goal': False, 'enable_air': False, + 'enable_artifical_wind': True, } ) diff --git a/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py b/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py index fa26c90..b3519af 100644 --- a/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py +++ b/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py @@ -16,7 +16,7 @@ class MPWrapper(RawInterfaceWrapper): [False] * 7, # joints velocity [True] * 2, # position ball x, y [False] * 1, # position ball z - # [False] * 3, # velocity ball x, y, z + [True] * 3, # velocity ball x, y, z [True] * 2, # target landing position # [True] * 1, # time ]) diff --git a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py index 7a77443..5283e54 100644 --- a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py +++ b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py @@ -24,7 +24,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4, enable_switching_goal: bool = False, - enable_wind: bool = False, enable_magnus: bool = False, + enable_wind: bool = False, + enable_artifical_wind: bool = False, + enable_magnus: bool = False, enable_air: bool = False): utils.EzPickle.__init__(**locals()) self._steps = 0 @@ -48,6 +50,10 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): self._enable_goal_switching = enable_switching_goal + self._enable_artifical_wind = enable_artifical_wind + + self._artifical_force = 0. + MujocoEnv.__init__(self, model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"), frame_skip=frame_skip, @@ -86,6 +92,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): unstable_simulation = False + if self._enable_goal_switching: if self._steps == 99 and self.np_random.uniform(0, 1) < 0.5: new_goal_pos = self._generate_goal_pos(random=True) @@ -95,6 +102,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): mujoco.mj_forward(self.model, self.data) for _ in range(self.frame_skip): + if self._enable_artifical_wind: + self.data.qfrc_applied[-2] = self._artifical_force try: self.do_simulation(action, 1) except Exception as e: @@ -154,7 +163,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): def reset_model(self): self._steps = 0 self._init_ball_state = self._generate_valid_init_ball(random_pos=True, random_vel=False) - self._init_ball_state[2] = 1.85 + # self._init_ball_state[2] = 1.85 self._goal_pos = self._generate_goal_pos(random=True) self.data.joint("tar_x").qpos = self._init_ball_state[0] self.data.joint("tar_y").qpos = self._init_ball_state[1] @@ -163,6 +172,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): self.data.joint("tar_y").qvel = self._init_ball_state[4] self.data.joint("tar_z").qvel = self._init_ball_state[5] + if self._enable_artifical_wind: + self._artifical_force = self.np_random.uniform(low=-0.1, high=0.1) + self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]]) self.data.qpos[:7] = np.array([0., 0., 0., 1.5, 0., 0., 1.5]) @@ -196,9 +208,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): self.data.joint("tar_x").qpos.copy(), self.data.joint("tar_y").qpos.copy(), self.data.joint("tar_z").qpos.copy(), - # self.data.joint("tar_x").qvel.copy(), - # self.data.joint("tar_y").qvel.copy(), - # self.data.joint("tar_z").qvel.copy(), + self.data.joint("tar_x").qvel.copy(), + self.data.joint("tar_y").qvel.copy(), + self.data.joint("tar_z").qvel.copy(), # self.data.body("target_ball").xvel.copy(), self._goal_pos.copy(), ]) @@ -265,7 +277,7 @@ def plot_compare_trajs(traj1, traj2, title): plt.show() if __name__ == "__main__": - env_air = TableTennisEnv(enable_air=True, enable_wind=False) + env_air = TableTennisEnv(enable_air=False, enable_wind=False, enable_artifical_wind=True) env_no_air = TableTennisEnv(enable_air=False, enable_wind=False) for _ in range(10): obs1 = env_air.reset() @@ -279,18 +291,17 @@ if __name__ == "__main__": # y_vel = [] # z_vel = [] for _ in range(2000): - # env_air.render("human") + env_air.render("human") obs1, reward1, done1, info1 = env_air.step(np.zeros(7)) obs2, reward2, done2, info2 = env_no_air.step(np.zeros(7)) - # _, _, _, _ = env_no_air.step(np.zeros(7)) + # # _, _, _, _ = env_no_air.step(np.zeros(7)) air_x_pos.append(env_air.data.joint("tar_z").qpos[0]) no_air_x_pos.append(env_no_air.data.joint("tar_z").qpos[0]) - # z_pos.append(env.data.joint("tar_z").qpos[0]) - # x_vel.append(env.data.joint("tar_x").qvel[0]) - # y_vel.append(env.data.joint("tar_y").qvel[0]) - # z_vel.append(env.data.joint("tar_z").qvel[0]) - # print(reward) + # # z_pos.append(env.data.joint("tar_z").qpos[0]) + # # x_vel.append(env.data.joint("tar_x").qvel[0]) + # # y_vel.append(env.data.joint("tar_y").qvel[0]) + # # z_vel.append(env.data.joint("tar_z").qvel[0]) + # # print(reward) if info1["num_steps"] == 150: - # plot_ball_traj_2d(x_pos, y_pos) - plot_compare_trajs(air_x_pos, no_air_x_pos, title="z_pos with/out air") + plot_compare_trajs(air_x_pos, no_air_x_pos, title="z_pos with/out wind") break