From bc40d4cf9bb6c80337c8c9882b6c98a451b6748f Mon Sep 17 00:00:00 2001 From: kngwyu Date: Sat, 30 May 2020 00:14:24 +0900 Subject: [PATCH] Temporarily exclude Block and BlockMaze --- mujoco_maze/__init__.py | 4 ++-- mujoco_maze/maze_env.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mujoco_maze/__init__.py b/mujoco_maze/__init__.py index 55bf2e4..67e0e0a 100644 --- a/mujoco_maze/__init__.py +++ b/mujoco_maze/__init__.py @@ -1,12 +1,12 @@ import gym -MAZE_IDS = ["Maze", "Push", "Fall", "Block", "BlockMaze"] +MAZE_IDS = ["Maze", "Push", "Fall"] # TODO: Block, BlockMaze def _get_kwargs(maze_id: str) -> tuple: return { "observe_blocks": maze_id in ["Block", "BlockMaze"], - "pin_spin_near_agent": maze_id in ["Block", "BlockMaze"], + "put_spin_near_agent": maze_id in ["Block", "BlockMaze"], } diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py index cb38edb..719b5c5 100644 --- a/mujoco_maze/maze_env.py +++ b/mujoco_maze/maze_env.py @@ -576,7 +576,7 @@ class MazeEnv(gym.Env): def _goal_fn(maze_id: str) -> callable: - if maze_id in ["Maze", "Push"]: + if maze_id in ["Maze", "Push", "BlockMaze"]: return lambda obs, goal: np.linalg.norm(obs[:2] - goal) <= 0.6 elif maze_id == "Fall": return lambda obs, goal: np.linalg.norm(obs[:3] - goal) <= 0.6 @@ -586,14 +586,14 @@ def _goal_fn(maze_id: str) -> callable: def _reward_fn(maze_id: str, dense: str) -> callable: if dense: - if maze_id in ["Maze", "Push"]: + if maze_id in ["Maze", "Push", "BlockMaze"]: return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5 elif maze_id == "Fall": return lambda obs, goal: -np.sum(np.square(obs[:3] - goal)) ** 0.5 else: raise NotImplementedError(f"Unknown maze id: {maze_id}") else: - if maze_id in ["Maze", "Push"]: + if maze_id in ["Maze", "Push", "BlockMaze"]: return lambda obs, goal: (np.linalg.norm(obs[:2] - goal) <= 0.6) * 1.0 elif maze_id == "Fall": return lambda obs, goal: (np.linalg.norm(obs[:3] - goal) <= 0.6) * 1.0 @@ -602,7 +602,7 @@ 
def _reward_fn(maze_id: str, dense: str) -> callable: def _default_goal(maze_id: str) -> np.ndarray: - if maze_id == "Maze": + if maze_id == "Maze" or maze_id == "BlockMaze": return np.array([0.0, 8.0]) elif maze_id == "Push": return np.array([0.0, 19.0])