Use top_down_view in Push and Fall
This commit is contained in:
parent
f0e4262c4d
commit
698d0acd94
@ -48,7 +48,7 @@ class MazeEnv(gym.Env):
|
||||
self.t = 0 # time steps
|
||||
self._observe_blocks = self._task.OBSERVE_BLOCKS
|
||||
self._put_spin_near_agent = self._task.PUT_SPIN_NEAR_AGENT
|
||||
self._top_down_view = top_down_view
|
||||
self._top_down_view = self._task.TOP_DOWN_VIEW
|
||||
self._restitution_coef = restitution_coef
|
||||
|
||||
self._maze_structure = structure = self._task.create_maze()
|
||||
@ -248,6 +248,10 @@ class MazeEnv(gym.Env):
|
||||
self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs)
|
||||
self.observation_space = self._get_obs_space()
|
||||
|
||||
@property
def has_extended_obs(self) -> bool:
    """Report whether either observation-extending flag is set.

    Returns ``self._top_down_view`` when it is truthy, and
    ``self._observe_blocks`` otherwise.
    """
    top_down = self._top_down_view
    if top_down:
        return top_down
    return self._observe_blocks
|
||||
|
||||
def get_ori(self) -> float:
    """Return whatever the wrapped environment's ``get_ori()`` reports."""
    inner_env = self.wrapped_env
    return inner_env.get_ori()
|
||||
|
||||
|
@ -52,6 +52,7 @@ class MazeTask(ABC):
|
||||
REWARD_THRESHOLD: float
|
||||
MAZE_SIZE_SCALING: Scaling = Scaling(8.0, 4.0)
|
||||
INNER_REWARD_SCALING: float = 0.01
|
||||
TOP_DOWN_VIEW: bool = False
|
||||
OBSERVE_BLOCKS: bool = False
|
||||
PUT_SPIN_NEAR_AGENT: bool = False
|
||||
|
||||
@ -114,6 +115,8 @@ class DistRewardUMaze(GoalRewardUMaze, DistRewardMixIn):
|
||||
|
||||
|
||||
class GoalRewardPush(GoalRewardUMaze):
|
||||
TOP_DOWN_VIEW = True
|
||||
|
||||
def __init__(self, scale: float) -> None:
|
||||
super().__init__(scale)
|
||||
self.goals = [MazeGoal(np.array([0.0, 2.375 * scale]))]
|
||||
@ -135,6 +138,8 @@ class DistRewardPush(GoalRewardPush, DistRewardMixIn):
|
||||
|
||||
|
||||
class GoalRewardFall(GoalRewardUMaze):
|
||||
TOP_DOWN_VIEW = True
|
||||
|
||||
def __init__(self, scale: float) -> None:
|
||||
super().__init__(scale)
|
||||
self.goals = [MazeGoal(np.array([0.0, 3.375 * scale, 4.5]))]
|
||||
|
@ -8,18 +8,22 @@ import mujoco_maze
|
||||
def test_ant_maze(maze_id):
    """Smoke-test the ``Ant{maze_id}-v0`` and ``-v1`` environments.

    Each environment is created, reset once, and stepped once with a random
    action. The observation shape is pinned to ``(30,)`` only when the
    environment does not use the top-down view.
    """
    for i in range(2):
        env = gym.make(f"Ant{maze_id}-v{i}")
        # Reset exactly once and keep the result, so the initial observation
        # can be checked under the same guard as the stepped one.
        s0 = env.reset()
        s, _, _, _ = env.step(env.action_space.sample())
        # The fixed shape is asserted only without the top-down view;
        # presumably that view changes the observation size, so an
        # unconditional check would fail for such tasks.
        if not env.unwrapped._top_down_view:
            assert s0.shape == (30,)
            assert s.shape == (30,)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("maze_id", mujoco_maze.TaskRegistry.keys())
def test_point_maze(maze_id):
    """Smoke-test the ``Point{maze_id}-v0`` and ``-v1`` environments.

    Runs once per key in ``mujoco_maze.TaskRegistry``. Each environment is
    created, reset once, and stepped once with a random action. The
    observation shape is pinned to ``(7,)`` only when the environment does
    not use the top-down view.
    """
    for i in range(2):
        env = gym.make(f"Point{maze_id}-v{i}")
        # Reset exactly once and keep the result, so the initial observation
        # can be checked under the same guard as the stepped one.
        s0 = env.reset()
        s, _, _, _ = env.step(env.action_space.sample())
        # The fixed shape is asserted only without the top-down view;
        # presumably that view changes the observation size, so an
        # unconditional check would fail for such tasks.
        if not env.unwrapped._top_down_view:
            assert s0.shape == (7,)
            assert s.shape == (7,)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("v", [0, 1])
|
||||
|
Loading…
Reference in New Issue
Block a user