From 807cd0ec4ddb5eee97c13df4a276ffc0cb86303d Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Mon, 22 Apr 2024 20:01:58 +0200
Subject: [PATCH] .

---
 mujoco_maze/ant.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mujoco_maze/ant.py b/mujoco_maze/ant.py
index d9dacdd..bb8448f 100644
--- a/mujoco_maze/ant.py
+++ b/mujoco_maze/ant.py
@@ -65,12 +65,12 @@ class AntEnv(AgentModel):
         super().__init__(file_path, 5, observation_space=self.observation_space)
 
     def _forward_reward(self, xy_pos_before: np.ndarray) -> Tuple[float, np.ndarray]:
-        xy_pos_after = self.sim.data.qpos[:2].copy()
+        xy_pos_after = self.data.qpos[:2].copy()
         xy_velocity = (xy_pos_after - xy_pos_before) / self.dt
         return self._forward_reward_fn(xy_velocity)
 
     def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
-        xy_pos_before = self.sim.data.qpos[:2].copy()
+        xy_pos_before = self.data.qpos[:2].copy()
         self.do_simulation(action, self.frame_skip)
 
         forward_reward = self._forward_reward(xy_pos_before)
@@ -108,15 +108,15 @@ class AntEnv(AgentModel):
 
     def get_ori(self) -> np.ndarray:
         ori = [0, 1, 0, 0]
-        rot = self.sim.data.qpos[self.ORI_IND : self.ORI_IND + 4]  # take the quaternion
+        rot = self.data.qpos[self.ORI_IND : self.ORI_IND + 4]  # take the quaternion
         ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3]  # project onto x-y plane
         ori = np.arctan2(ori[1], ori[0])
         return ori
 
     def set_xy(self, xy: np.ndarray) -> None:
-        qpos = self.sim.data.qpos.copy()
+        qpos = self.data.qpos.copy()
         qpos[:2] = xy
-        self.set_state(qpos, self.sim.data.qvel)
+        self.set_state(qpos, self.data.qvel)
 
     def get_xy(self) -> np.ndarray:
-        return np.copy(self.sim.data.qpos[:2])
+        return np.copy(self.data.qpos[:2])