Refactor maze_env

This commit is contained in:
kngwyu 2021-05-07 19:10:19 +09:00
parent 4087203b06
commit 409ee44568

View File

@ -27,7 +27,7 @@ class MazeEnv(gym.Env):
self,
model_cls: Type[AgentModel],
maze_task: Type[maze_task.MazeTask] = maze_task.MazeTask,
top_down_view: float = False,
include_position: bool = True,
maze_height: float = 0.5,
maze_size_scaling: float = 4.0,
inner_reward_scaling: float = 1.0,
@ -36,16 +36,11 @@ class MazeEnv(gym.Env):
websock_port: Optional[int] = None,
**kwargs,
) -> None:
self.t = 0 # time steps
self._task = maze_task(maze_size_scaling, **task_kwargs)
xml_path = os.path.join(MODEL_DIR, model_cls.FILE)
tree = ET.parse(xml_path)
worldbody = tree.find(".//worldbody")
self._maze_height = height = maze_height
self._maze_size_scaling = size_scaling = maze_size_scaling
self._inner_reward_scaling = inner_reward_scaling
self.t = 0 # time steps
self._observe_blocks = self._task.OBSERVE_BLOCKS
self._put_spin_near_agent = self._task.PUT_SPIN_NEAR_AGENT
# Observe other objectives
@ -94,6 +89,11 @@ class MazeEnv(gym.Env):
# walls (immovable), chasms (fall), movable blocks
self._view = np.zeros([5, 5, 3])
# Let's create MuJoCo XML
xml_path = os.path.join(MODEL_DIR, model_cls.FILE)
tree = ET.parse(xml_path)
worldbody = tree.find(".//worldbody")
height_offset = 0.0
if self.elevated:
# Increase initial z-pos of ant.