From 409ee44568579633894f88bd3f6c222cd5d3157e Mon Sep 17 00:00:00 2001 From: kngwyu Date: Fri, 7 May 2021 19:10:19 +0900 Subject: [PATCH] Refactor maze_env --- mujoco_maze/maze_env.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py index 6b89d87..d78b3bb 100644 --- a/mujoco_maze/maze_env.py +++ b/mujoco_maze/maze_env.py @@ -27,7 +27,7 @@ class MazeEnv(gym.Env): self, model_cls: Type[AgentModel], maze_task: Type[maze_task.MazeTask] = maze_task.MazeTask, - top_down_view: float = False, + include_position: bool = True, maze_height: float = 0.5, maze_size_scaling: float = 4.0, inner_reward_scaling: float = 1.0, @@ -36,16 +36,11 @@ class MazeEnv(gym.Env): websock_port: Optional[int] = None, **kwargs, ) -> None: + self.t = 0 # time steps self._task = maze_task(maze_size_scaling, **task_kwargs) - - xml_path = os.path.join(MODEL_DIR, model_cls.FILE) - tree = ET.parse(xml_path) - worldbody = tree.find(".//worldbody") - self._maze_height = height = maze_height self._maze_size_scaling = size_scaling = maze_size_scaling self._inner_reward_scaling = inner_reward_scaling - self.t = 0 # time steps self._observe_blocks = self._task.OBSERVE_BLOCKS self._put_spin_near_agent = self._task.PUT_SPIN_NEAR_AGENT # Observe other objectives @@ -94,6 +89,11 @@ class MazeEnv(gym.Env): # walls (immovable), chasms (fall), movable blocks self._view = np.zeros([5, 5, 3]) + # Let's create MuJoCo XML + xml_path = os.path.join(MODEL_DIR, model_cls.FILE) + tree = ET.parse(xml_path) + worldbody = tree.find(".//worldbody") + height_offset = 0.0 if self.elevated: # Increase initial z-pos of ant.