This commit is contained in:
Dominik Moritz Roth 2024-04-22 20:01:58 +02:00
parent 9b3a5ccca5
commit 807cd0ec4d

View File

@@ -65,12 +65,12 @@ class AntEnv(AgentModel):
super().__init__(file_path, 5, observation_space=self.observation_space) super().__init__(file_path, 5, observation_space=self.observation_space)
def _forward_reward(self, xy_pos_before: np.ndarray) -> Tuple[float, np.ndarray]: def _forward_reward(self, xy_pos_before: np.ndarray) -> Tuple[float, np.ndarray]:
xy_pos_after = self.sim.data.qpos[:2].copy() xy_pos_after = self.data.qpos[:2].copy()
xy_velocity = (xy_pos_after - xy_pos_before) / self.dt xy_velocity = (xy_pos_after - xy_pos_before) / self.dt
return self._forward_reward_fn(xy_velocity) return self._forward_reward_fn(xy_velocity)
def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]: def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
xy_pos_before = self.sim.data.qpos[:2].copy() xy_pos_before = self.data.qpos[:2].copy()
self.do_simulation(action, self.frame_skip) self.do_simulation(action, self.frame_skip)
forward_reward = self._forward_reward(xy_pos_before) forward_reward = self._forward_reward(xy_pos_before)
@@ -108,15 +108,15 @@ class AntEnv(AgentModel):
def get_ori(self) -> np.ndarray: def get_ori(self) -> np.ndarray:
ori = [0, 1, 0, 0] ori = [0, 1, 0, 0]
rot = self.sim.data.qpos[self.ORI_IND : self.ORI_IND + 4] # take the quaternion rot = self.data.qpos[self.ORI_IND : self.ORI_IND + 4] # take the quaternion
ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3] # project onto x-y plane ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3] # project onto x-y plane
ori = np.arctan2(ori[1], ori[0]) ori = np.arctan2(ori[1], ori[0])
return ori return ori
def set_xy(self, xy: np.ndarray) -> None: def set_xy(self, xy: np.ndarray) -> None:
qpos = self.sim.data.qpos.copy() qpos = self.data.qpos.copy()
qpos[:2] = xy qpos[:2] = xy
self.set_state(qpos, self.sim.data.qvel) self.set_state(qpos, self.data.qvel)
def get_xy(self) -> np.ndarray: def get_xy(self) -> np.ndarray:
return np.copy(self.sim.data.qpos[:2]) return np.copy(self.data.qpos[:2])