Return position in info

This commit is contained in:
kngwyu 2020-07-06 00:52:14 +09:00
parent 266ef3b855
commit cb9dcc554e
2 changed files with 4 additions and 3 deletions

View File

@ -45,14 +45,14 @@ class AntEnv(AgentModel):
def __init__(
self,
file_path: Optional[str] = None,
ctrl_cost_weight: float = 0.5,
ctrl_cost_weight: float = 0.0001,
forward_reward_fn: ForwardRewardFn = forward_reward_vnorm,
) -> None:
self._ctrl_cost_weight = ctrl_cost_weight
self._forward_reward_fn = forward_reward_fn
super().__init__(file_path, 5)
def _forward_reward(self, xy_pos_before: np.ndarray) -> float:
def _forward_reward(self, xy_pos_before: np.ndarray) -> Tuple[float, np.ndarray]:
xy_pos_after = self.sim.data.qpos[:2].copy()
xy_velocity = (xy_pos_after - xy_pos_before) / self.dt
return self._forward_reward_fn(xy_velocity)

View File

@ -26,7 +26,7 @@ class MazeEnv(gym.Env):
def __init__(
self,
model_cls: Type[AgentModel],
maze_task: Type[maze_task.MazeTask] = maze_task.SingleGoalSparseUMaze,
maze_task: Type[maze_task.MazeTask] = maze_task.MazeTask,
n_bins: int = 0,
sensor_range: float = 3.0,
sensor_span: float = 2 * np.pi,
@ -542,4 +542,5 @@ class MazeEnv(gym.Env):
inner_reward = self._inner_reward_scaling * inner_reward
outer_reward = self._task.reward(next_obs)
done = self._task.termination(next_obs)
info["position"] = self.wrapped_env.get_xy()
return next_obs, inner_reward + outer_reward, done, info