Temporally exclude Block and BlockMaze

This commit is contained in:
kngwyu 2020-05-30 00:14:24 +09:00
parent 38f87fbb2d
commit bc40d4cf9b
2 changed files with 6 additions and 6 deletions

View File

@ -1,12 +1,12 @@
import gym
MAZE_IDS = ["Maze", "Push", "Fall", "Block", "BlockMaze"]
MAZE_IDS = ["Maze", "Push", "Fall"] # TODO: Block, BlockMaze
def _get_kwargs(maze_id: str) -> tuple:
return {
"observe_blocks": maze_id in ["Block", "BlockMaze"],
"pin_spin_near_agent": maze_id in ["Block", "BlockMaze"],
"put_spin_near_agent": maze_id in ["Block", "BlockMaze"],
}

View File

@ -576,7 +576,7 @@ class MazeEnv(gym.Env):
def _goal_fn(maze_id: str) -> callable:
if maze_id in ["Maze", "Push"]:
if maze_id in ["Maze", "Push", "BlockMaze"]:
return lambda obs, goal: np.linalg.norm(obs[:2] - goal) <= 0.6
elif maze_id == "Fall":
return lambda obs, goal: np.linalg.norm(obs[:3] - goal) <= 0.6
@ -586,14 +586,14 @@ def _goal_fn(maze_id: str) -> callable:
def _reward_fn(maze_id: str, dense: str) -> callable:
if dense:
if maze_id in ["Maze", "Push"]:
if maze_id in ["Maze", "Push", "BlockMaze"]:
return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5
elif maze_id == "Fall":
return lambda obs, goal: -np.sum(np.square(obs[:3] - goal)) ** 0.5
else:
raise NotImplementedError(f"Unknown maze id: {maze_id}")
else:
if maze_id in ["Maze", "Push"]:
if maze_id in ["Maze", "Push", "BlockMaze"]:
return lambda obs, goal: (np.linalg.norm(obs[:2] - goal) <= 0.6) * 1.0
elif maze_id == "Fall":
return lambda obs, goal: (np.linalg.norm(obs[:3] - goal) <= 0.6) * 1.0
@ -602,7 +602,7 @@ def _reward_fn(maze_id: str, dense: str) -> callable:
def _default_goal(maze_id: str) -> np.ndarray:
if maze_id == "Maze":
if maze_id == "Maze" or maze_id == "BlockMaze":
return np.array([0.0, 8.0])
elif maze_id == "Push":
return np.array([0.0, 19.0])