Introduce MazeCell and refactor
This commit is contained in:
parent
d4f588cf0c
commit
6bc5cf9739
@ -76,10 +76,10 @@ class MazeEnv(gym.Env):
|
|||||||
maze_id=self._maze_id
|
maze_id=self._maze_id
|
||||||
)
|
)
|
||||||
# Elevate the maze to allow for falling.
|
# Elevate the maze to allow for falling.
|
||||||
self.elevated = any(-1 in row for row in structure)
|
self.elevated = any(maze_env_utils.MazeCell.CHASM in row for row in structure)
|
||||||
# Are there any movable blocks?
|
# Are there any movable blocks?
|
||||||
self.blocks = any(
|
self.blocks = any(
|
||||||
any(maze_env_utils.can_move(r) for r in row) for row in structure
|
any(r.can_move() for r in row) for row in structure
|
||||||
)
|
)
|
||||||
|
|
||||||
torso_x, torso_y = self._find_robot()
|
torso_x, torso_y = self._find_robot()
|
||||||
@ -116,9 +116,9 @@ class MazeEnv(gym.Env):
|
|||||||
for i in range(len(structure)):
|
for i in range(len(structure)):
|
||||||
for j in range(len(structure[0])):
|
for j in range(len(structure[0])):
|
||||||
struct = structure[i][j]
|
struct = structure[i][j]
|
||||||
if struct == "r" and self._put_spin_near_agent:
|
if struct.is_robot() and self._put_spin_near_agent:
|
||||||
struct = maze_env_utils.Move.SpinXY
|
struct = maze_env_utils.Move.SpinXY
|
||||||
if self.elevated and struct not in [-1]:
|
if self.elevated and not struct.is_chasm():
|
||||||
# Create elevated platform.
|
# Create elevated platform.
|
||||||
x = j * size_scaling - torso_x
|
x = j * size_scaling - torso_x
|
||||||
y = i * size_scaling - torso_y
|
y = i * size_scaling - torso_y
|
||||||
@ -136,7 +136,7 @@ class MazeEnv(gym.Env):
|
|||||||
conaffinity="1",
|
conaffinity="1",
|
||||||
rgba="0.9 0.9 0.9 1",
|
rgba="0.9 0.9 0.9 1",
|
||||||
)
|
)
|
||||||
if struct == 1:
|
if struct.is_block():
|
||||||
# Unmovable block.
|
# Unmovable block.
|
||||||
# Offset all coordinates so that robot starts at the origin.
|
# Offset all coordinates so that robot starts at the origin.
|
||||||
x = j * size_scaling - torso_x
|
x = j * size_scaling - torso_x
|
||||||
@ -155,14 +155,14 @@ class MazeEnv(gym.Env):
|
|||||||
conaffinity="1",
|
conaffinity="1",
|
||||||
rgba="0.4 0.4 0.4 1",
|
rgba="0.4 0.4 0.4 1",
|
||||||
)
|
)
|
||||||
elif maze_env_utils.can_move(struct):
|
elif struct.can_move():
|
||||||
# Movable block.
|
# Movable block.
|
||||||
# The "falling" blocks are shrunk slightly and increased in mass to
|
# The "falling" blocks are shrunk slightly and increased in mass to
|
||||||
# ensure it can fall easily through a gap in the platform blocks.
|
# ensure it can fall easily through a gap in the platform blocks.
|
||||||
name = "movable_%d_%d" % (i, j)
|
name = "movable_%d_%d" % (i, j)
|
||||||
self.movable_blocks.append((name, struct))
|
self.movable_blocks.append((name, struct))
|
||||||
falling = maze_env_utils.can_move_z(struct)
|
falling = struct.can_move_z()
|
||||||
spinning = maze_env_utils.can_spin(struct)
|
spinning = struct.can_spin()
|
||||||
shrink = 0.1 if spinning else 0.99 if falling else 1.0
|
shrink = 0.1 if spinning else 0.99 if falling else 1.0
|
||||||
height_shrink = 0.1 if spinning else 1.0
|
height_shrink = 0.1 if spinning else 1.0
|
||||||
x = (
|
x = (
|
||||||
@ -192,7 +192,7 @@ class MazeEnv(gym.Env):
|
|||||||
conaffinity="1",
|
conaffinity="1",
|
||||||
rgba="0.9 0.1 0.1 1",
|
rgba="0.9 0.1 0.1 1",
|
||||||
)
|
)
|
||||||
if maze_env_utils.can_move_x(struct):
|
if struct.can_move_x():
|
||||||
ET.SubElement(
|
ET.SubElement(
|
||||||
movable_body,
|
movable_body,
|
||||||
"joint",
|
"joint",
|
||||||
@ -206,7 +206,7 @@ class MazeEnv(gym.Env):
|
|||||||
pos="0 0 0",
|
pos="0 0 0",
|
||||||
type="slide",
|
type="slide",
|
||||||
)
|
)
|
||||||
if maze_env_utils.can_move_y(struct):
|
if struct.can_move_y():
|
||||||
ET.SubElement(
|
ET.SubElement(
|
||||||
movable_body,
|
movable_body,
|
||||||
"joint",
|
"joint",
|
||||||
@ -220,7 +220,7 @@ class MazeEnv(gym.Env):
|
|||||||
pos="0 0 0",
|
pos="0 0 0",
|
||||||
type="slide",
|
type="slide",
|
||||||
)
|
)
|
||||||
if maze_env_utils.can_move_z(struct):
|
if struct.can_move_z():
|
||||||
ET.SubElement(
|
ET.SubElement(
|
||||||
movable_body,
|
movable_body,
|
||||||
"joint",
|
"joint",
|
||||||
@ -234,7 +234,7 @@ class MazeEnv(gym.Env):
|
|||||||
pos="0 0 0",
|
pos="0 0 0",
|
||||||
type="slide",
|
type="slide",
|
||||||
)
|
)
|
||||||
if maze_env_utils.can_spin(struct):
|
if struct.can_spin():
|
||||||
ET.SubElement(
|
ET.SubElement(
|
||||||
movable_body,
|
movable_body,
|
||||||
"joint",
|
"joint",
|
||||||
@ -353,13 +353,13 @@ class MazeEnv(gym.Env):
|
|||||||
# Draw immovable blocks and chasms.
|
# Draw immovable blocks and chasms.
|
||||||
for i in range(len(structure)):
|
for i in range(len(structure)):
|
||||||
for j in range(len(structure[0])):
|
for j in range(len(structure[0])):
|
||||||
if structure[i][j] == 1: # Wall.
|
if structure[i][j].is_block(): # Wall.
|
||||||
update_view(
|
update_view(
|
||||||
j * size_scaling - self._init_torso_x,
|
j * size_scaling - self._init_torso_x,
|
||||||
i * size_scaling - self._init_torso_y,
|
i * size_scaling - self._init_torso_y,
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
if structure[i][j] == -1: # Chasm.
|
if structure[i][j].is_chasm(): # Chasm.
|
||||||
update_view(
|
update_view(
|
||||||
j * size_scaling - self._init_torso_x,
|
j * size_scaling - self._init_torso_x,
|
||||||
i * size_scaling - self._init_torso_y,
|
i * size_scaling - self._init_torso_y,
|
||||||
@ -387,7 +387,7 @@ class MazeEnv(gym.Env):
|
|||||||
# block or drop-off.
|
# block or drop-off.
|
||||||
for i in range(len(structure)):
|
for i in range(len(structure)):
|
||||||
for j in range(len(structure[0])):
|
for j in range(len(structure[0])):
|
||||||
if structure[i][j] in [1, -1]: # There's a wall or drop-off.
|
if structure[i][j].is_wall_or_chasm(): # There's a wall or drop-off.
|
||||||
cx = j * size_scaling - self._init_torso_x
|
cx = j * size_scaling - self._init_torso_x
|
||||||
cy = i * size_scaling - self._init_torso_y
|
cy = i * size_scaling - self._init_torso_y
|
||||||
x1 = cx - 0.5 * size_scaling
|
x1 = cx - 0.5 * size_scaling
|
||||||
@ -456,7 +456,7 @@ class MazeEnv(gym.Env):
|
|||||||
idx = 0 # Wall
|
idx = 0 # Wall
|
||||||
elif seg_type == -1:
|
elif seg_type == -1:
|
||||||
idx = 1 # Drop-off
|
idx = 1 # Drop-off
|
||||||
elif maze_env_utils.can_move(seg_type):
|
elif seg_type.can_move():
|
||||||
idx == 2 # Block
|
idx == 2 # Block
|
||||||
sr = self._sensor_range
|
sr = self._sensor_range
|
||||||
if first_seg["distance"] <= sr:
|
if first_seg["distance"] <= sr:
|
||||||
@ -516,7 +516,7 @@ class MazeEnv(gym.Env):
|
|||||||
structure = self._maze_structure
|
structure = self._maze_structure
|
||||||
size_scaling = self._maze_size_scaling
|
size_scaling = self._maze_size_scaling
|
||||||
for i, j in it.product(range(len(structure)), range(len(structure[0]))):
|
for i, j in it.product(range(len(structure)), range(len(structure[0]))):
|
||||||
if structure[i][j] == "r":
|
if structure[i][j].is_robot():
|
||||||
return j * size_scaling, i * size_scaling
|
return j * size_scaling, i * size_scaling
|
||||||
raise ValueError("No robot in maze specification.")
|
raise ValueError("No robot in maze specification.")
|
||||||
|
|
||||||
@ -525,7 +525,7 @@ class MazeEnv(gym.Env):
|
|||||||
size_scaling = self._maze_size_scaling
|
size_scaling = self._maze_size_scaling
|
||||||
coords = []
|
coords = []
|
||||||
for i, j in it.product(range(len(structure)), range(len(structure[0]))):
|
for i, j in it.product(range(len(structure)), range(len(structure[0]))):
|
||||||
if structure[i][j] == "r":
|
if structure[i][j].is_robot():
|
||||||
coords.append((j * size_scaling, i * size_scaling))
|
coords.append((j * size_scaling, i * size_scaling))
|
||||||
return coords
|
return coords
|
||||||
|
|
||||||
|
@ -14,12 +14,20 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
"""Adapted from rllab maze_env_utils.py."""
|
"""Adapted from rllab maze_env_utils.py."""
|
||||||
|
from enum import Enum
|
||||||
import itertools as it
|
import itertools as it
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class Move:
|
class MazeCell(Enum):
|
||||||
|
# Robot: Start position
|
||||||
|
ROBOT = -1
|
||||||
|
# Blocks
|
||||||
|
EMPTY = 0
|
||||||
|
BLOCK = 1
|
||||||
|
CHASM = 2
|
||||||
|
# Moves
|
||||||
X = 11
|
X = 11
|
||||||
Y = 12
|
Y = 12
|
||||||
Z = 13
|
Z = 13
|
||||||
@ -29,69 +37,88 @@ class Move:
|
|||||||
XYZ = 17
|
XYZ = 17
|
||||||
SpinXY = 18
|
SpinXY = 18
|
||||||
|
|
||||||
|
def is_block(self) -> bool:
|
||||||
|
return self == self.BLOCK
|
||||||
|
|
||||||
def can_move_x(movable):
|
def is_chasm(self) -> bool:
|
||||||
return movable in [Move.X, Move.XY, Move.XZ, Move.XYZ, Move.SpinXY]
|
return self == self.CHASM
|
||||||
|
|
||||||
|
def is_robot(self) -> bool:
|
||||||
|
return self == self.ROBOT
|
||||||
|
|
||||||
def can_move_y(movable):
|
def is_wall_or_chasm(self) -> bool:
|
||||||
return movable in [Move.Y, Move.XY, Move.YZ, Move.XYZ, Move.SpinXY]
|
return self in [self.BLOCK, self.CHASM]
|
||||||
|
|
||||||
|
def can_move_x(self) -> bool:
|
||||||
|
return self in [
|
||||||
|
self.X,
|
||||||
|
self.XY,
|
||||||
|
self.XZ,
|
||||||
|
self.XYZ,
|
||||||
|
self.SpinXY,
|
||||||
|
]
|
||||||
|
|
||||||
def can_move_z(movable):
|
def can_move_y(self):
|
||||||
return movable in [Move.Z, Move.XZ, Move.YZ, Move.XYZ]
|
return self in [
|
||||||
|
self.Y,
|
||||||
|
self.XY,
|
||||||
|
self.YZ,
|
||||||
|
self.XYZ,
|
||||||
|
self.SpinXY,
|
||||||
|
]
|
||||||
|
|
||||||
|
def can_move_z(self):
|
||||||
|
return self in [self.Z, self.XZ, self.YZ, self.XYZ]
|
||||||
|
|
||||||
def can_spin(movable):
|
def can_spin(self):
|
||||||
return movable in [Move.SpinXY]
|
return self == self.SpinXY
|
||||||
|
|
||||||
|
def can_move(self):
|
||||||
def can_move(movable):
|
return self.can_move_x() or self.can_move_y() or self.can_move_z()
|
||||||
return can_move_x(movable) or can_move_y(movable) or can_move_z(movable)
|
|
||||||
|
|
||||||
|
|
||||||
def construct_maze(maze_id="Maze"):
|
def construct_maze(maze_id="Maze"):
|
||||||
R = "r"
|
E, B, C, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.CHASM, MazeCell.ROBOT
|
||||||
if maze_id == "Maze":
|
if maze_id == "Maze":
|
||||||
structure = [
|
structure = [
|
||||||
[1, 1, 1, 1, 1],
|
[B, B, B, B, B],
|
||||||
[1, R, 0, 0, 1],
|
[B, R, E, E, B],
|
||||||
[1, 1, 1, 0, 1],
|
[B, B, B, E, B],
|
||||||
[1, 0, 0, 0, 1],
|
[B, E, E, E, B],
|
||||||
[1, 1, 1, 1, 1],
|
[B, B, B, B, B],
|
||||||
]
|
]
|
||||||
elif maze_id == "Push":
|
elif maze_id == "Push":
|
||||||
structure = [
|
structure = [
|
||||||
[1, 1, 1, 1, 1],
|
[B, B, B, B, B],
|
||||||
[1, 0, R, 1, 1],
|
[B, E, R, B, B],
|
||||||
[1, 0, Move.XY, 0, 1],
|
[B, E, MazeCell.XY, E, B],
|
||||||
[1, 1, 0, 1, 1],
|
[B, B, E, B, B],
|
||||||
[1, 1, 1, 1, 1],
|
[B, B, B, B, B],
|
||||||
]
|
]
|
||||||
elif maze_id == "Fall":
|
elif maze_id == "Fall":
|
||||||
structure = [
|
structure = [
|
||||||
[1, 1, 1, 1],
|
[B, B, B, B],
|
||||||
[1, R, 0, 1],
|
[B, R, E, B],
|
||||||
[1, 0, Move.YZ, 1],
|
[B, E, MazeCell.YZ, B],
|
||||||
[1, -1, -1, 1],
|
[B, C, C, B],
|
||||||
[1, 0, 0, 1],
|
[B, E, E, B],
|
||||||
[1, 1, 1, 1],
|
[B, B, B, B],
|
||||||
]
|
]
|
||||||
elif maze_id == "Block":
|
elif maze_id == "Block":
|
||||||
structure = [
|
structure = [
|
||||||
[1, 1, 1, 1, 1],
|
[B, B, B, B, B],
|
||||||
[1, R, 0, 0, 1],
|
[B, R, E, E, B],
|
||||||
[1, 0, 0, 0, 1],
|
[B, E, E, E, B],
|
||||||
[1, 0, 0, 0, 1],
|
[B, E, E, E, B],
|
||||||
[1, 1, 1, 1, 1],
|
[B, B, B, B, B],
|
||||||
]
|
]
|
||||||
elif maze_id == "BlockMaze":
|
elif maze_id == "BlockMaze":
|
||||||
structure = [
|
structure = [
|
||||||
[1, 1, 1, 1],
|
[B, B, B, B],
|
||||||
[1, R, 0, 1],
|
[B, R, E, B],
|
||||||
[1, 1, 0, 1],
|
[B, B, E, B],
|
||||||
[1, 0, 0, 1],
|
[B, E, E, B],
|
||||||
[1, 1, 1, 1],
|
[B, B, B, B],
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("The provided MazeId %s is not recognized" % maze_id)
|
raise NotImplementedError("The provided MazeId %s is not recognized" % maze_id)
|
||||||
@ -115,7 +142,7 @@ class Collision:
|
|||||||
def is_block(pos) -> bool:
|
def is_block(pos) -> bool:
|
||||||
i, j = pos
|
i, j = pos
|
||||||
if 0 <= i < h and 0 <= j < w:
|
if 0 <= i < h and 0 <= j < w:
|
||||||
return structure[i][j] == 1
|
return structure[i][j].is_block()
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -123,7 +150,7 @@ class Collision:
|
|||||||
return self.OFFSET[is_block(pos + self.ARROUND[index])]
|
return self.OFFSET[is_block(pos + self.ARROUND[index])]
|
||||||
|
|
||||||
for i, j in it.product(range(len(structure)), range(len(structure[0]))):
|
for i, j in it.product(range(len(structure)), range(len(structure[0]))):
|
||||||
if structure[i][j] != 1:
|
if not structure[i][j].is_block():
|
||||||
continue
|
continue
|
||||||
pos = np.array([i, j])
|
pos = np.array([i, j])
|
||||||
y_base = i * size_scaling - torso_y
|
y_base = i * size_scaling - torso_y
|
||||||
|
@ -17,6 +17,8 @@ mujoco-py = ">=1.5"
|
|||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
pytest = "^3.0"
|
pytest = "^3.0"
|
||||||
|
|
||||||
|
[tool.poetry.scripts]
|
||||||
|
test = "pytest:main"
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 88
|
line-length = 88
|
||||||
|
Loading…
Reference in New Issue
Block a user