diff --git a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py index 275bba1..2408404 100644 --- a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py +++ b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py @@ -13,6 +13,7 @@ MAX_EPISODE_STEPS_BOX_PUSHING = 100 BOX_POS_BOUND = np.array([[0.3, -0.45, -0.01], [0.6, 0.45, -0.01]]) + class BoxPushingEnvBase(MujocoEnv, utils.EzPickle): """ franka box pushing environment @@ -41,8 +42,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle): self._episode_energy = 0. MujocoEnv.__init__(self, model_path=os.path.join(os.path.dirname(__file__), "assets", "box_pushing.xml"), - frame_skip=self.frame_skip, - mujoco_bindings="mujoco") + frame_skip=self.frame_skip) self.action_space = spaces.Box(low=-1, high=1, shape=(7,)) def step(self, action): @@ -246,7 +246,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle): old_err_norm = err_norm - ### get Jacobian by mujoco + # get Jacobian by mujoco self.data.qpos[:7] = q mujoco.mj_forward(self.model, self.data) @@ -280,6 +280,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle): return q + class BoxPushingDense(BoxPushingEnvBase): def __init__(self, frame_skip: int = 10): super(BoxPushingDense, self).__init__(frame_skip=frame_skip) @@ -295,7 +296,7 @@ class BoxPushingDense(BoxPushingEnvBase): energy_cost = -0.0005 * np.sum(np.square(action)) reward = joint_penalty + tcp_box_dist_reward + \ - box_goal_pos_dist_reward + box_goal_rot_dist_reward + energy_cost + box_goal_pos_dist_reward + box_goal_rot_dist_reward + energy_cost rod_inclined_angle = rotation_distance(rod_quat, self._desired_rod_quat) if rod_inclined_angle > np.pi / 4: @@ -303,6 +304,7 @@ class BoxPushingDense(BoxPushingEnvBase): return reward + class BoxPushingTemporalSparse(BoxPushingEnvBase): def __init__(self, frame_skip: int = 10): super(BoxPushingTemporalSparse, self).__init__(frame_skip=frame_skip) @@ -331,6 +333,7 @@ class BoxPushingTemporalSparse(BoxPushingEnvBase): return reward + class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase): def __init__(self, frame_skip: int = 10):