Refactor maze_env
This commit is contained in:
parent
4087203b06
commit
409ee44568
@ -27,7 +27,7 @@ class MazeEnv(gym.Env):
|
||||
self,
|
||||
model_cls: Type[AgentModel],
|
||||
maze_task: Type[maze_task.MazeTask] = maze_task.MazeTask,
|
||||
top_down_view: float = False,
|
||||
include_position: bool = True,
|
||||
maze_height: float = 0.5,
|
||||
maze_size_scaling: float = 4.0,
|
||||
inner_reward_scaling: float = 1.0,
|
||||
@ -36,16 +36,11 @@ class MazeEnv(gym.Env):
|
||||
websock_port: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.t = 0 # time steps
|
||||
self._task = maze_task(maze_size_scaling, **task_kwargs)
|
||||
|
||||
xml_path = os.path.join(MODEL_DIR, model_cls.FILE)
|
||||
tree = ET.parse(xml_path)
|
||||
worldbody = tree.find(".//worldbody")
|
||||
|
||||
self._maze_height = height = maze_height
|
||||
self._maze_size_scaling = size_scaling = maze_size_scaling
|
||||
self._inner_reward_scaling = inner_reward_scaling
|
||||
self.t = 0 # time steps
|
||||
self._observe_blocks = self._task.OBSERVE_BLOCKS
|
||||
self._put_spin_near_agent = self._task.PUT_SPIN_NEAR_AGENT
|
||||
# Observe other objectives
|
||||
@ -94,6 +89,11 @@ class MazeEnv(gym.Env):
|
||||
# walls (immovable), chasms (fall), movable blocks
|
||||
self._view = np.zeros([5, 5, 3])
|
||||
|
||||
# Let's create MuJoCo XML
|
||||
xml_path = os.path.join(MODEL_DIR, model_cls.FILE)
|
||||
tree = ET.parse(xml_path)
|
||||
worldbody = tree.find(".//worldbody")
|
||||
|
||||
height_offset = 0.0
|
||||
if self.elevated:
|
||||
# Increase initial z-pos of ant.
|
||||
|
Loading…
Reference in New Issue
Block a user