Use top_down_view in Push and Fall
This commit is contained in:
parent
f0e4262c4d
commit
698d0acd94
@ -48,7 +48,7 @@ class MazeEnv(gym.Env):
|
|||||||
self.t = 0 # time steps
|
self.t = 0 # time steps
|
||||||
self._observe_blocks = self._task.OBSERVE_BLOCKS
|
self._observe_blocks = self._task.OBSERVE_BLOCKS
|
||||||
self._put_spin_near_agent = self._task.PUT_SPIN_NEAR_AGENT
|
self._put_spin_near_agent = self._task.PUT_SPIN_NEAR_AGENT
|
||||||
self._top_down_view = top_down_view
|
self._top_down_view = self._task.TOP_DOWN_VIEW
|
||||||
self._restitution_coef = restitution_coef
|
self._restitution_coef = restitution_coef
|
||||||
|
|
||||||
self._maze_structure = structure = self._task.create_maze()
|
self._maze_structure = structure = self._task.create_maze()
|
||||||
@ -248,6 +248,10 @@ class MazeEnv(gym.Env):
|
|||||||
self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs)
|
self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs)
|
||||||
self.observation_space = self._get_obs_space()
|
self.observation_space = self._get_obs_space()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_extended_obs(self) -> bool:
|
||||||
|
return self._top_down_view or self._observe_blocks
|
||||||
|
|
||||||
def get_ori(self) -> float:
|
def get_ori(self) -> float:
|
||||||
return self.wrapped_env.get_ori()
|
return self.wrapped_env.get_ori()
|
||||||
|
|
||||||
|
@ -52,6 +52,7 @@ class MazeTask(ABC):
|
|||||||
REWARD_THRESHOLD: float
|
REWARD_THRESHOLD: float
|
||||||
MAZE_SIZE_SCALING: Scaling = Scaling(8.0, 4.0)
|
MAZE_SIZE_SCALING: Scaling = Scaling(8.0, 4.0)
|
||||||
INNER_REWARD_SCALING: float = 0.01
|
INNER_REWARD_SCALING: float = 0.01
|
||||||
|
TOP_DOWN_VIEW: bool = False
|
||||||
OBSERVE_BLOCKS: bool = False
|
OBSERVE_BLOCKS: bool = False
|
||||||
PUT_SPIN_NEAR_AGENT: bool = False
|
PUT_SPIN_NEAR_AGENT: bool = False
|
||||||
|
|
||||||
@ -114,6 +115,8 @@ class DistRewardUMaze(GoalRewardUMaze, DistRewardMixIn):
|
|||||||
|
|
||||||
|
|
||||||
class GoalRewardPush(GoalRewardUMaze):
|
class GoalRewardPush(GoalRewardUMaze):
|
||||||
|
TOP_DOWN_VIEW = True
|
||||||
|
|
||||||
def __init__(self, scale: float) -> None:
|
def __init__(self, scale: float) -> None:
|
||||||
super().__init__(scale)
|
super().__init__(scale)
|
||||||
self.goals = [MazeGoal(np.array([0.0, 2.375 * scale]))]
|
self.goals = [MazeGoal(np.array([0.0, 2.375 * scale]))]
|
||||||
@ -135,6 +138,8 @@ class DistRewardPush(GoalRewardPush, DistRewardMixIn):
|
|||||||
|
|
||||||
|
|
||||||
class GoalRewardFall(GoalRewardUMaze):
|
class GoalRewardFall(GoalRewardUMaze):
|
||||||
|
TOP_DOWN_VIEW = True
|
||||||
|
|
||||||
def __init__(self, scale: float) -> None:
|
def __init__(self, scale: float) -> None:
|
||||||
super().__init__(scale)
|
super().__init__(scale)
|
||||||
self.goals = [MazeGoal(np.array([0.0, 3.375 * scale, 4.5]))]
|
self.goals = [MazeGoal(np.array([0.0, 3.375 * scale, 4.5]))]
|
||||||
|
@ -8,18 +8,22 @@ import mujoco_maze
|
|||||||
def test_ant_maze(maze_id):
|
def test_ant_maze(maze_id):
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
env = gym.make(f"Ant{maze_id}-v{i}")
|
env = gym.make(f"Ant{maze_id}-v{i}")
|
||||||
assert env.reset().shape == (30,)
|
s0 = env.reset()
|
||||||
s, _, _, _ = env.step(env.action_space.sample())
|
s, _, _, _ = env.step(env.action_space.sample())
|
||||||
assert s.shape == (30,)
|
if not env.unwrapped._top_down_view:
|
||||||
|
assert s0.shape == (30,)
|
||||||
|
assert s.shape == (30,)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("maze_id", mujoco_maze.TaskRegistry.keys())
|
@pytest.mark.parametrize("maze_id", mujoco_maze.TaskRegistry.keys())
|
||||||
def test_point_maze(maze_id):
|
def test_point_maze(maze_id):
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
env = gym.make(f"Point{maze_id}-v{i}")
|
env = gym.make(f"Point{maze_id}-v{i}")
|
||||||
assert env.reset().shape == (7,)
|
s0 = env.reset()
|
||||||
s, _, _, _ = env.step(env.action_space.sample())
|
s, _, _, _ = env.step(env.action_space.sample())
|
||||||
assert s.shape == (7,)
|
if not env.unwrapped._top_down_view:
|
||||||
|
assert s0.shape == (7,)
|
||||||
|
assert s.shape == (7,)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("v", [0, 1])
|
@pytest.mark.parametrize("v", [0, 1])
|
||||||
|
Loading…
Reference in New Issue
Block a user