Use top_down_view in Push and Fall

This commit is contained in:
kngwyu 2020-09-21 13:27:41 +09:00
parent f0e4262c4d
commit 698d0acd94
3 changed files with 18 additions and 5 deletions

View File

@ -48,7 +48,7 @@ class MazeEnv(gym.Env):
self.t = 0 # time steps self.t = 0 # time steps
self._observe_blocks = self._task.OBSERVE_BLOCKS self._observe_blocks = self._task.OBSERVE_BLOCKS
self._put_spin_near_agent = self._task.PUT_SPIN_NEAR_AGENT self._put_spin_near_agent = self._task.PUT_SPIN_NEAR_AGENT
self._top_down_view = top_down_view self._top_down_view = self._task.TOP_DOWN_VIEW
self._restitution_coef = restitution_coef self._restitution_coef = restitution_coef
self._maze_structure = structure = self._task.create_maze() self._maze_structure = structure = self._task.create_maze()
@ -248,6 +248,10 @@ class MazeEnv(gym.Env):
self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs) self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs)
self.observation_space = self._get_obs_space() self.observation_space = self._get_obs_space()
@property
def has_extended_obs(self) -> bool:
return self._top_down_view or self._observe_blocks
def get_ori(self) -> float: def get_ori(self) -> float:
return self.wrapped_env.get_ori() return self.wrapped_env.get_ori()

View File

@ -52,6 +52,7 @@ class MazeTask(ABC):
REWARD_THRESHOLD: float REWARD_THRESHOLD: float
MAZE_SIZE_SCALING: Scaling = Scaling(8.0, 4.0) MAZE_SIZE_SCALING: Scaling = Scaling(8.0, 4.0)
INNER_REWARD_SCALING: float = 0.01 INNER_REWARD_SCALING: float = 0.01
TOP_DOWN_VIEW: bool = False
OBSERVE_BLOCKS: bool = False OBSERVE_BLOCKS: bool = False
PUT_SPIN_NEAR_AGENT: bool = False PUT_SPIN_NEAR_AGENT: bool = False
@ -114,6 +115,8 @@ class DistRewardUMaze(GoalRewardUMaze, DistRewardMixIn):
class GoalRewardPush(GoalRewardUMaze): class GoalRewardPush(GoalRewardUMaze):
TOP_DOWN_VIEW = True
def __init__(self, scale: float) -> None: def __init__(self, scale: float) -> None:
super().__init__(scale) super().__init__(scale)
self.goals = [MazeGoal(np.array([0.0, 2.375 * scale]))] self.goals = [MazeGoal(np.array([0.0, 2.375 * scale]))]
@ -135,6 +138,8 @@ class DistRewardPush(GoalRewardPush, DistRewardMixIn):
class GoalRewardFall(GoalRewardUMaze): class GoalRewardFall(GoalRewardUMaze):
TOP_DOWN_VIEW = True
def __init__(self, scale: float) -> None: def __init__(self, scale: float) -> None:
super().__init__(scale) super().__init__(scale)
self.goals = [MazeGoal(np.array([0.0, 3.375 * scale, 4.5]))] self.goals = [MazeGoal(np.array([0.0, 3.375 * scale, 4.5]))]

View File

@ -8,18 +8,22 @@ import mujoco_maze
def test_ant_maze(maze_id): def test_ant_maze(maze_id):
for i in range(2): for i in range(2):
env = gym.make(f"Ant{maze_id}-v{i}") env = gym.make(f"Ant{maze_id}-v{i}")
assert env.reset().shape == (30,) s0 = env.reset()
s, _, _, _ = env.step(env.action_space.sample()) s, _, _, _ = env.step(env.action_space.sample())
assert s.shape == (30,) if not env.unwrapped._top_down_view:
assert s0.shape == (30,)
assert s.shape == (30,)
@pytest.mark.parametrize("maze_id", mujoco_maze.TaskRegistry.keys()) @pytest.mark.parametrize("maze_id", mujoco_maze.TaskRegistry.keys())
def test_point_maze(maze_id): def test_point_maze(maze_id):
for i in range(2): for i in range(2):
env = gym.make(f"Point{maze_id}-v{i}") env = gym.make(f"Point{maze_id}-v{i}")
assert env.reset().shape == (7,) s0 = env.reset()
s, _, _, _ = env.step(env.action_space.sample()) s, _, _, _ = env.step(env.action_space.sample())
assert s.shape == (7,) if not env.unwrapped._top_down_view:
assert s0.shape == (7,)
assert s.shape == (7,)
@pytest.mark.parametrize("v", [0, 1]) @pytest.mark.parametrize("v", [0, 1])