Temporally exclude Block and BlockMaze
This commit is contained in:
parent
38f87fbb2d
commit
bc40d4cf9b
@ -1,12 +1,12 @@
|
||||
import gym
|
||||
|
||||
MAZE_IDS = ["Maze", "Push", "Fall", "Block", "BlockMaze"]
|
||||
MAZE_IDS = ["Maze", "Push", "Fall"] # TODO: Block, BlockMaze
|
||||
|
||||
|
||||
def _get_kwargs(maze_id: str) -> tuple:
|
||||
return {
|
||||
"observe_blocks": maze_id in ["Block", "BlockMaze"],
|
||||
"pin_spin_near_agent": maze_id in ["Block", "BlockMaze"],
|
||||
"put_spin_near_agent": maze_id in ["Block", "BlockMaze"],
|
||||
}
|
||||
|
||||
|
||||
|
@ -576,7 +576,7 @@ class MazeEnv(gym.Env):
|
||||
|
||||
|
||||
def _goal_fn(maze_id: str) -> callable:
|
||||
if maze_id in ["Maze", "Push"]:
|
||||
if maze_id in ["Maze", "Push", "BlockMaze"]:
|
||||
return lambda obs, goal: np.linalg.norm(obs[:2] - goal) <= 0.6
|
||||
elif maze_id == "Fall":
|
||||
return lambda obs, goal: np.linalg.norm(obs[:3] - goal) <= 0.6
|
||||
@ -586,14 +586,14 @@ def _goal_fn(maze_id: str) -> callable:
|
||||
|
||||
def _reward_fn(maze_id: str, dense: str) -> callable:
|
||||
if dense:
|
||||
if maze_id in ["Maze", "Push"]:
|
||||
if maze_id in ["Maze", "Push", "BlockMaze"]:
|
||||
return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5
|
||||
elif maze_id == "Fall":
|
||||
return lambda obs, goal: -np.sum(np.square(obs[:3] - goal)) ** 0.5
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown maze id: {maze_id}")
|
||||
else:
|
||||
if maze_id in ["Maze", "Push"]:
|
||||
if maze_id in ["Maze", "Push", "BlockMaze"]:
|
||||
return lambda obs, goal: (np.linalg.norm(obs[:2] - goal) <= 0.6) * 1.0
|
||||
elif maze_id == "Fall":
|
||||
return lambda obs, goal: (np.linalg.norm(obs[:3] - goal) <= 0.6) * 1.0
|
||||
@ -602,7 +602,7 @@ def _reward_fn(maze_id: str, dense: str) -> callable:
|
||||
|
||||
|
||||
def _default_goal(maze_id: str) -> np.ndarray:
|
||||
if maze_id == "Maze":
|
||||
if maze_id == "Maze" or maze_id == "BlockMaze":
|
||||
return np.array([0.0, 8.0])
|
||||
elif maze_id == "Push":
|
||||
return np.array([0.0, 19.0])
|
||||
|
Loading…
Reference in New Issue
Block a user