From 87633a89fb3abaeedc1fc4e228f37ac3a7aaa083 Mon Sep 17 00:00:00 2001 From: Hongyi Zhou Date: Fri, 11 Nov 2022 23:41:35 +0100 Subject: [PATCH] add enable_wind option to table tennis environment --- .../envs/mujoco/table_tennis/table_tennis_env.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py index cbe52da..507513f 100644 --- a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py +++ b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py @@ -20,7 +20,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): 7 DoF table tennis environment """ - def __init__(self, ctxt_dim: int = 2, frame_skip: int = 4): + def __init__(self, ctxt_dim: int = 2, frame_skip: int = 4, + enable_wind: bool = False, enable_magnus: bool = False): utils.EzPickle.__init__(**locals()) self._steps = 0 @@ -54,6 +55,13 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32) + # complex dynamics settings + self.model.opt.density = 1.225 + self.model.opt.viscosity = 2.27e-5 + self._enable_wind = enable_wind + self._enable_magnus = enable_magnus + self._wind_vel = np.zeros(3) + def _set_ids(self): self._floor_contact_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_GEOM, "floor") self._ball_contact_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_GEOM, "target_ball_contact") @@ -142,6 +150,10 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): mujoco.mj_forward(self.model, self.data) + if self._enable_wind: + self._wind_vel[1] = self.np_random.uniform(low=-10, high=10, size=1) + self.model.opt.wind[:3] = self._wind_vel + self._hit_ball = False self._ball_land_on_table = False self._ball_contact_after_hit = False @@ -208,7 +220,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): if __name__ == "__main__": - env = TableTennisEnv() + env = TableTennisEnv(enable_wind=True) for _ in range(1000): obs = env.reset() for _ in range(2000):