Add TRoom

This commit is contained in:
kngwyu 2020-09-17 01:27:38 +09:00
parent 4a195d9848
commit 8f630b58c2
3 changed files with 50 additions and 2 deletions

View File

@ -32,10 +32,11 @@ class MazeEnv(gym.Env):
maze_size_scaling: float = 4.0, maze_size_scaling: float = 4.0,
inner_reward_scaling: float = 1.0, inner_reward_scaling: float = 1.0,
restitution_coef: float = 0.8, restitution_coef: float = 0.8,
task_kwargs: dict = {},
*args, *args,
**kwargs, **kwargs,
) -> None: ) -> None:
self._task = maze_task(maze_size_scaling) self._task = maze_task(maze_size_scaling, **task_kwargs)
xml_path = os.path.join(MODEL_DIR, model_cls.FILE) xml_path = os.path.join(MODEL_DIR, model_cls.FILE)
tree = ET.parse(xml_path) tree = ET.parse(xml_path)

View File

@ -2,7 +2,7 @@
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, NamedTuple, Type from typing import Dict, List, NamedTuple, Tuple, Type
import numpy as np import numpy as np
@ -237,6 +237,44 @@ class SubGoal4Rooms(GoalReward4Rooms):
] ]
class GoalRewardTRoom(MazeTask):
REWARD_THRESHOLD: float = 0.9
MAZE_SIZE_SCALING: Scaling = Scaling(4.0, 4.0)
def __init__(
self,
scale: float,
goals: List[Tuple[float, float]] = [(2.0, -4.0)],
) -> None:
super().__init__(scale)
self.goals = []
for x, y in goals:
self.goals.append(MazeGoal(np.array([x * scale, y * scale])))
def reward(self, obs: np.ndarray) -> float:
for goal in self.goals:
if goal.neighbor(obs):
return goal.reward_scale
return -0.0001
@staticmethod
def create_maze() -> List[List[MazeCell]]:
E, B, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.ROBOT
return [
[B, B, B, B, B, B, B],
[B, E, E, E, E, E, B],
[B, E, E, B, E, E, B],
[B, E, E, B, E, E, B],
[B, E, B, B, B, E, B],
[B, E, E, R, E, E, B],
[B, B, B, B, B, B, B],
]
class DistRewardTRoom(GoalRewardTRoom, DistRewardMixIn):
pass
class TaskRegistry: class TaskRegistry:
REGISTRY: Dict[str, List[Type[MazeTask]]] = { REGISTRY: Dict[str, List[Type[MazeTask]]] = {
"UMaze": [DistRewardUMaze, GoalRewardUMaze], "UMaze": [DistRewardUMaze, GoalRewardUMaze],
@ -244,6 +282,7 @@ class TaskRegistry:
"Fall": [DistRewardFall, GoalRewardFall], "Fall": [DistRewardFall, GoalRewardFall],
"2Rooms": [DistReward2Rooms, GoalReward2Rooms, SubGoal2Rooms], "2Rooms": [DistReward2Rooms, GoalReward2Rooms, SubGoal2Rooms],
"4Rooms": [DistReward4Rooms, GoalReward4Rooms, SubGoal4Rooms], "4Rooms": [DistReward4Rooms, GoalReward4Rooms, SubGoal4Rooms],
"TRoom": [DistRewardTRoom, GoalRewardTRoom],
} }
@staticmethod @staticmethod

View File

@ -20,3 +20,11 @@ def test_point_maze(maze_id):
assert env.reset().shape == (7,) assert env.reset().shape == (7,)
s, _, _, _ = env.step(env.action_space.sample()) s, _, _, _ = env.step(env.action_space.sample())
assert s.shape == (7,) assert s.shape == (7,)
@pytest.mark.parametrize("v", [0, 1])
def test_maze_args(v):
env = gym.make(f"PointTRoom-v{v}", task_kwargs={"goals": [(-2.0, 4.0)]})
assert env.reset().shape == (7,)
s, _, _, _ = env.step(env.action_space.sample())
assert s.shape == (7,)