Refactor maze_env
This commit is contained in:
parent
4087203b06
commit
409ee44568
@ -27,7 +27,7 @@ class MazeEnv(gym.Env):
|
|||||||
self,
|
self,
|
||||||
model_cls: Type[AgentModel],
|
model_cls: Type[AgentModel],
|
||||||
maze_task: Type[maze_task.MazeTask] = maze_task.MazeTask,
|
maze_task: Type[maze_task.MazeTask] = maze_task.MazeTask,
|
||||||
top_down_view: float = False,
|
include_position: bool = True,
|
||||||
maze_height: float = 0.5,
|
maze_height: float = 0.5,
|
||||||
maze_size_scaling: float = 4.0,
|
maze_size_scaling: float = 4.0,
|
||||||
inner_reward_scaling: float = 1.0,
|
inner_reward_scaling: float = 1.0,
|
||||||
@ -36,16 +36,11 @@ class MazeEnv(gym.Env):
|
|||||||
websock_port: Optional[int] = None,
|
websock_port: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
self.t = 0 # time steps
|
||||||
self._task = maze_task(maze_size_scaling, **task_kwargs)
|
self._task = maze_task(maze_size_scaling, **task_kwargs)
|
||||||
|
|
||||||
xml_path = os.path.join(MODEL_DIR, model_cls.FILE)
|
|
||||||
tree = ET.parse(xml_path)
|
|
||||||
worldbody = tree.find(".//worldbody")
|
|
||||||
|
|
||||||
self._maze_height = height = maze_height
|
self._maze_height = height = maze_height
|
||||||
self._maze_size_scaling = size_scaling = maze_size_scaling
|
self._maze_size_scaling = size_scaling = maze_size_scaling
|
||||||
self._inner_reward_scaling = inner_reward_scaling
|
self._inner_reward_scaling = inner_reward_scaling
|
||||||
self.t = 0 # time steps
|
|
||||||
self._observe_blocks = self._task.OBSERVE_BLOCKS
|
self._observe_blocks = self._task.OBSERVE_BLOCKS
|
||||||
self._put_spin_near_agent = self._task.PUT_SPIN_NEAR_AGENT
|
self._put_spin_near_agent = self._task.PUT_SPIN_NEAR_AGENT
|
||||||
# Observe other objectives
|
# Observe other objectives
|
||||||
@ -94,6 +89,11 @@ class MazeEnv(gym.Env):
|
|||||||
# walls (immovable), chasms (fall), movable blocks
|
# walls (immovable), chasms (fall), movable blocks
|
||||||
self._view = np.zeros([5, 5, 3])
|
self._view = np.zeros([5, 5, 3])
|
||||||
|
|
||||||
|
# Let's create MuJoCo XML
|
||||||
|
xml_path = os.path.join(MODEL_DIR, model_cls.FILE)
|
||||||
|
tree = ET.parse(xml_path)
|
||||||
|
worldbody = tree.find(".//worldbody")
|
||||||
|
|
||||||
height_offset = 0.0
|
height_offset = 0.0
|
||||||
if self.elevated:
|
if self.elevated:
|
||||||
# Increase initial z-pos of ant.
|
# Increase initial z-pos of ant.
|
||||||
|
Loading…
Reference in New Issue
Block a user