Temporally exclude Block and BlockMaze
This commit is contained in:
parent
38f87fbb2d
commit
bc40d4cf9b
@ -1,12 +1,12 @@
|
|||||||
import gym
|
import gym
|
||||||
|
|
||||||
MAZE_IDS = ["Maze", "Push", "Fall", "Block", "BlockMaze"]
|
MAZE_IDS = ["Maze", "Push", "Fall"] # TODO: Block, BlockMaze
|
||||||
|
|
||||||
|
|
||||||
def _get_kwargs(maze_id: str) -> tuple:
|
def _get_kwargs(maze_id: str) -> tuple:
|
||||||
return {
|
return {
|
||||||
"observe_blocks": maze_id in ["Block", "BlockMaze"],
|
"observe_blocks": maze_id in ["Block", "BlockMaze"],
|
||||||
"pin_spin_near_agent": maze_id in ["Block", "BlockMaze"],
|
"put_spin_near_agent": maze_id in ["Block", "BlockMaze"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -576,7 +576,7 @@ class MazeEnv(gym.Env):
|
|||||||
|
|
||||||
|
|
||||||
def _goal_fn(maze_id: str) -> callable:
|
def _goal_fn(maze_id: str) -> callable:
|
||||||
if maze_id in ["Maze", "Push"]:
|
if maze_id in ["Maze", "Push", "BlockMaze"]:
|
||||||
return lambda obs, goal: np.linalg.norm(obs[:2] - goal) <= 0.6
|
return lambda obs, goal: np.linalg.norm(obs[:2] - goal) <= 0.6
|
||||||
elif maze_id == "Fall":
|
elif maze_id == "Fall":
|
||||||
return lambda obs, goal: np.linalg.norm(obs[:3] - goal) <= 0.6
|
return lambda obs, goal: np.linalg.norm(obs[:3] - goal) <= 0.6
|
||||||
@ -586,14 +586,14 @@ def _goal_fn(maze_id: str) -> callable:
|
|||||||
|
|
||||||
def _reward_fn(maze_id: str, dense: str) -> callable:
|
def _reward_fn(maze_id: str, dense: str) -> callable:
|
||||||
if dense:
|
if dense:
|
||||||
if maze_id in ["Maze", "Push"]:
|
if maze_id in ["Maze", "Push", "BlockMaze"]:
|
||||||
return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5
|
return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5
|
||||||
elif maze_id == "Fall":
|
elif maze_id == "Fall":
|
||||||
return lambda obs, goal: -np.sum(np.square(obs[:3] - goal)) ** 0.5
|
return lambda obs, goal: -np.sum(np.square(obs[:3] - goal)) ** 0.5
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unknown maze id: {maze_id}")
|
raise NotImplementedError(f"Unknown maze id: {maze_id}")
|
||||||
else:
|
else:
|
||||||
if maze_id in ["Maze", "Push"]:
|
if maze_id in ["Maze", "Push", "BlockMaze"]:
|
||||||
return lambda obs, goal: (np.linalg.norm(obs[:2] - goal) <= 0.6) * 1.0
|
return lambda obs, goal: (np.linalg.norm(obs[:2] - goal) <= 0.6) * 1.0
|
||||||
elif maze_id == "Fall":
|
elif maze_id == "Fall":
|
||||||
return lambda obs, goal: (np.linalg.norm(obs[:3] - goal) <= 0.6) * 1.0
|
return lambda obs, goal: (np.linalg.norm(obs[:3] - goal) <= 0.6) * 1.0
|
||||||
@ -602,7 +602,7 @@ def _reward_fn(maze_id: str, dense: str) -> callable:
|
|||||||
|
|
||||||
|
|
||||||
def _default_goal(maze_id: str) -> np.ndarray:
|
def _default_goal(maze_id: str) -> np.ndarray:
|
||||||
if maze_id == "Maze":
|
if maze_id == "Maze" or maze_id == "BlockMaze":
|
||||||
return np.array([0.0, 8.0])
|
return np.array([0.0, 8.0])
|
||||||
elif maze_id == "Push":
|
elif maze_id == "Push":
|
||||||
return np.array([0.0, 19.0])
|
return np.array([0.0, 19.0])
|
||||||
|
Loading…
Reference in New Issue
Block a user