.
This commit is contained in:
parent
9b3a5ccca5
commit
807cd0ec4d
@ -65,12 +65,12 @@ class AntEnv(AgentModel):
|
|||||||
super().__init__(file_path, 5, observation_space=self.observation_space)
|
super().__init__(file_path, 5, observation_space=self.observation_space)
|
||||||
|
|
||||||
def _forward_reward(self, xy_pos_before: np.ndarray) -> Tuple[float, np.ndarray]:
|
def _forward_reward(self, xy_pos_before: np.ndarray) -> Tuple[float, np.ndarray]:
|
||||||
xy_pos_after = self.sim.data.qpos[:2].copy()
|
xy_pos_after = self.data.qpos[:2].copy()
|
||||||
xy_velocity = (xy_pos_after - xy_pos_before) / self.dt
|
xy_velocity = (xy_pos_after - xy_pos_before) / self.dt
|
||||||
return self._forward_reward_fn(xy_velocity)
|
return self._forward_reward_fn(xy_velocity)
|
||||||
|
|
||||||
def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
|
def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
|
||||||
xy_pos_before = self.sim.data.qpos[:2].copy()
|
xy_pos_before = self.data.qpos[:2].copy()
|
||||||
self.do_simulation(action, self.frame_skip)
|
self.do_simulation(action, self.frame_skip)
|
||||||
|
|
||||||
forward_reward = self._forward_reward(xy_pos_before)
|
forward_reward = self._forward_reward(xy_pos_before)
|
||||||
@ -108,15 +108,15 @@ class AntEnv(AgentModel):
|
|||||||
|
|
||||||
def get_ori(self) -> np.ndarray:
|
def get_ori(self) -> np.ndarray:
|
||||||
ori = [0, 1, 0, 0]
|
ori = [0, 1, 0, 0]
|
||||||
rot = self.sim.data.qpos[self.ORI_IND : self.ORI_IND + 4] # take the quaternion
|
rot = self.data.qpos[self.ORI_IND : self.ORI_IND + 4] # take the quaternion
|
||||||
ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3] # project onto x-y plane
|
ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3] # project onto x-y plane
|
||||||
ori = np.arctan2(ori[1], ori[0])
|
ori = np.arctan2(ori[1], ori[0])
|
||||||
return ori
|
return ori
|
||||||
|
|
||||||
def set_xy(self, xy: np.ndarray) -> None:
|
def set_xy(self, xy: np.ndarray) -> None:
|
||||||
qpos = self.sim.data.qpos.copy()
|
qpos = self.data.qpos.copy()
|
||||||
qpos[:2] = xy
|
qpos[:2] = xy
|
||||||
self.set_state(qpos, self.sim.data.qvel)
|
self.set_state(qpos, self.data.qvel)
|
||||||
|
|
||||||
def get_xy(self) -> np.ndarray:
|
def get_xy(self) -> np.ndarray:
|
||||||
return np.copy(self.sim.data.qpos[:2])
|
return np.copy(self.data.qpos[:2])
|
||||||
|
Loading…
Reference in New Issue
Block a user