Return position in info
This commit is contained in:
parent
266ef3b855
commit
cb9dcc554e
@ -45,14 +45,14 @@ class AntEnv(AgentModel):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
file_path: Optional[str] = None,
|
file_path: Optional[str] = None,
|
||||||
ctrl_cost_weight: float = 0.5,
|
ctrl_cost_weight: float = 0.0001,
|
||||||
forward_reward_fn: ForwardRewardFn = forward_reward_vnorm,
|
forward_reward_fn: ForwardRewardFn = forward_reward_vnorm,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._ctrl_cost_weight = ctrl_cost_weight
|
self._ctrl_cost_weight = ctrl_cost_weight
|
||||||
self._forward_reward_fn = forward_reward_fn
|
self._forward_reward_fn = forward_reward_fn
|
||||||
super().__init__(file_path, 5)
|
super().__init__(file_path, 5)
|
||||||
|
|
||||||
def _forward_reward(self, xy_pos_before: np.ndarray) -> float:
|
def _forward_reward(self, xy_pos_before: np.ndarray) -> Tuple[float, np.ndarray]:
|
||||||
xy_pos_after = self.sim.data.qpos[:2].copy()
|
xy_pos_after = self.sim.data.qpos[:2].copy()
|
||||||
xy_velocity = (xy_pos_after - xy_pos_before) / self.dt
|
xy_velocity = (xy_pos_after - xy_pos_before) / self.dt
|
||||||
return self._forward_reward_fn(xy_velocity)
|
return self._forward_reward_fn(xy_velocity)
|
||||||
|
@ -26,7 +26,7 @@ class MazeEnv(gym.Env):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_cls: Type[AgentModel],
|
model_cls: Type[AgentModel],
|
||||||
maze_task: Type[maze_task.MazeTask] = maze_task.SingleGoalSparseUMaze,
|
maze_task: Type[maze_task.MazeTask] = maze_task.MazeTask,
|
||||||
n_bins: int = 0,
|
n_bins: int = 0,
|
||||||
sensor_range: float = 3.0,
|
sensor_range: float = 3.0,
|
||||||
sensor_span: float = 2 * np.pi,
|
sensor_span: float = 2 * np.pi,
|
||||||
@ -542,4 +542,5 @@ class MazeEnv(gym.Env):
|
|||||||
inner_reward = self._inner_reward_scaling * inner_reward
|
inner_reward = self._inner_reward_scaling * inner_reward
|
||||||
outer_reward = self._task.reward(next_obs)
|
outer_reward = self._task.reward(next_obs)
|
||||||
done = self._task.termination(next_obs)
|
done = self._task.termination(next_obs)
|
||||||
|
info["position"] = self.wrapped_env.get_xy()
|
||||||
return next_obs, inner_reward + outer_reward, done, info
|
return next_obs, inner_reward + outer_reward, done, info
|
||||||
|
Loading…
Reference in New Issue
Block a user