Use top_down_view in Push and Fall

This commit is contained in:
kngwyu 2020-09-21 13:27:41 +09:00
parent f0e4262c4d
commit 698d0acd94
3 changed files with 18 additions and 5 deletions

View File

@ -48,7 +48,7 @@ class MazeEnv(gym.Env):
self.t = 0 # time steps
self._observe_blocks = self._task.OBSERVE_BLOCKS
self._put_spin_near_agent = self._task.PUT_SPIN_NEAR_AGENT
self._top_down_view = top_down_view
self._top_down_view = self._task.TOP_DOWN_VIEW
self._restitution_coef = restitution_coef
self._maze_structure = structure = self._task.create_maze()
@ -248,6 +248,10 @@ class MazeEnv(gym.Env):
self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs)
self.observation_space = self._get_obs_space()
@property
def has_extended_obs(self) -> bool:
    """Tell whether the top-down view or block observation is switched on.

    Returns:
        The value of ``self._top_down_view`` when it is truthy, otherwise
        the value of ``self._observe_blocks``.
    """
    top_down = self._top_down_view
    return top_down if top_down else self._observe_blocks
def get_ori(self) -> float:
    """Return the orientation reported by the wrapped environment.

    The call is forwarded unchanged to ``self.wrapped_env.get_ori()``.
    """
    wrapped = self.wrapped_env
    return wrapped.get_ori()

View File

@ -52,6 +52,7 @@ class MazeTask(ABC):
REWARD_THRESHOLD: float
MAZE_SIZE_SCALING: Scaling = Scaling(8.0, 4.0)
INNER_REWARD_SCALING: float = 0.01
TOP_DOWN_VIEW: bool = False
OBSERVE_BLOCKS: bool = False
PUT_SPIN_NEAR_AGENT: bool = False
@ -114,6 +115,8 @@ class DistRewardUMaze(GoalRewardUMaze, DistRewardMixIn):
class GoalRewardPush(GoalRewardUMaze):
TOP_DOWN_VIEW = True
def __init__(self, scale: float) -> None:
    """Set up the Push task with a single goal.

    Args:
        scale: Factor forwarded to the parent initializer and used to
            scale the second coordinate of the goal position.
    """
    super().__init__(scale)
    # The parent's setup is replaced by exactly one goal at
    # (0.0, 2.375 * scale).
    goal_position = np.array([0.0, 2.375 * scale])
    self.goals = [MazeGoal(goal_position)]
@ -135,6 +138,8 @@ class DistRewardPush(GoalRewardPush, DistRewardMixIn):
class GoalRewardFall(GoalRewardUMaze):
TOP_DOWN_VIEW = True
def __init__(self, scale: float) -> None:
    """Set up the Fall task with a single goal.

    Args:
        scale: Factor forwarded to the parent initializer and used to
            scale the second coordinate of the goal position.
    """
    super().__init__(scale)
    # One three-component goal; only the second component depends on
    # ``scale``. NOTE(review): the fixed 4.5 is presumably a height --
    # confirm against MazeGoal.
    goal_position = np.array([0.0, 3.375 * scale, 4.5])
    self.goals = [MazeGoal(goal_position)]

View File

@ -8,8 +8,10 @@ import mujoco_maze
def test_ant_maze(maze_id):
    """Check observation shapes for versions 0 and 1 of an Ant maze env.

    Each environment is reset once and stepped once with a sampled action.
    The ``(30,)`` shape is asserted only when the environment does not use
    the top-down view.

    Args:
        maze_id: Maze name inserted into the environment id
            ``Ant{maze_id}-v{i}``.
    """
    for i in range(2):
        env = gym.make(f"Ant{maze_id}-v{i}")
        # Reset exactly once and keep the result. An unconditional shape
        # assertion on a second reset would contradict the guard below.
        s0 = env.reset()
        s, _, _, _ = env.step(env.action_space.sample())
        # NOTE(review): presumably the top-down view changes the
        # observation size, so the fixed-shape check is skipped for it.
        if not env.unwrapped._top_down_view:
            assert s0.shape == (30,)
            assert s.shape == (30,)
@ -17,8 +19,10 @@ def test_ant_maze(maze_id):
def test_point_maze(maze_id):
    """Check observation shapes for versions 0 and 1 of a Point maze env.

    Each environment is reset once and stepped once with a sampled action.
    The ``(7,)`` shape is asserted only when the environment does not use
    the top-down view.

    Args:
        maze_id: Maze name inserted into the environment id
            ``Point{maze_id}-v{i}``.
    """
    for i in range(2):
        env = gym.make(f"Point{maze_id}-v{i}")
        # Reset exactly once and keep the result. An unconditional shape
        # assertion on a second reset would contradict the guard below.
        s0 = env.reset()
        s, _, _, _ = env.step(env.action_space.sample())
        # NOTE(review): presumably the top-down view changes the
        # observation size, so the fixed-shape check is skipped for it.
        if not env.unwrapped._top_down_view:
            assert s0.shape == (7,)
            assert s.shape == (7,)