Merge pull request #2 from kngwyu/troom

TRoom
This commit is contained in:
Yuji Kanagawa 2020-09-18 22:28:25 +09:00 committed by GitHub
commit f0e4262c4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 60 additions and 2 deletions

View File

@ -17,6 +17,13 @@ class AgentModel(ABC, MujocoEnv, EzPickle):
MujocoEnv.__init__(self, file_path, frame_skip) MujocoEnv.__init__(self, file_path, frame_skip)
EzPickle.__init__(self) EzPickle.__init__(self)
def close(self):
    """Release the viewer window (if one exists) and close the base env.

    When a viewer is present and exposes a ``window`` attribute, that
    window is destroyed through GLFW before the parent ``close`` runs.
    """
    viewer = self.viewer
    has_window = viewer is not None and hasattr(viewer, "window")
    if has_window:
        # Imported only on this path; presumably so glfw is required
        # solely when a windowed viewer was created -- TODO confirm.
        import glfw

        glfw.destroy_window(viewer.window)
    super().close()
@abstractmethod @abstractmethod
def _get_obs(self) -> np.ndarray: def _get_obs(self) -> np.ndarray:
"""Returns the observation from the model. """Returns the observation from the model.

View File

@ -32,10 +32,11 @@ class MazeEnv(gym.Env):
maze_size_scaling: float = 4.0, maze_size_scaling: float = 4.0,
inner_reward_scaling: float = 1.0, inner_reward_scaling: float = 1.0,
restitution_coef: float = 0.8, restitution_coef: float = 0.8,
task_kwargs: dict = {},
*args, *args,
**kwargs, **kwargs,
) -> None: ) -> None:
self._task = maze_task(maze_size_scaling) self._task = maze_task(maze_size_scaling, **task_kwargs)
xml_path = os.path.join(MODEL_DIR, model_cls.FILE) xml_path = os.path.join(MODEL_DIR, model_cls.FILE)
tree = ET.parse(xml_path) tree = ET.parse(xml_path)
@ -452,3 +453,6 @@ class MazeEnv(gym.Env):
done = self._task.termination(next_obs) done = self._task.termination(next_obs)
info["position"] = self.wrapped_env.get_xy() info["position"] = self.wrapped_env.get_xy()
return next_obs, inner_reward + outer_reward, done, info return next_obs, inner_reward + outer_reward, done, info
def close(self):
    """Close the wrapped environment held in ``self.wrapped_env``."""
    inner_env = self.wrapped_env
    inner_env.close()

View File

@ -2,7 +2,7 @@
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, NamedTuple, Type from typing import Dict, List, NamedTuple, Tuple, Type
import numpy as np import numpy as np
@ -237,6 +237,44 @@ class SubGoal4Rooms(GoalReward4Rooms):
] ]
class GoalRewardTRoom(MazeTask):
    """Goal-reward task for the ``TRoom`` maze layout.

    Each ``(x, y)`` pair in ``goals`` is multiplied by ``scale`` and wrapped
    in a ``MazeGoal``. ``reward`` pays the ``reward_scale`` of the first goal
    whose ``neighbor`` check succeeds, and a small negative value otherwise.
    """

    REWARD_THRESHOLD: float = 0.9
    MAZE_SIZE_SCALING: Scaling = Scaling(4.0, 4.0)

    def __init__(
        self,
        scale: float,
        goals: List[Tuple[float, float]] = [(2.0, -4.0)],
    ) -> None:
        """Build one ``MazeGoal`` per entry of ``goals``.

        Args:
            scale: Factor forwarded to the parent constructor and applied to
                every goal coordinate.
            goals: ``(x, y)`` pairs; only iterated here, never mutated.
        """
        super().__init__(scale)
        self.goals = [
            MazeGoal(np.array([gx * scale, gy * scale])) for gx, gy in goals
        ]

    def reward(self, obs: np.ndarray) -> float:
        """Return the reward for ``obs``.

        Goals are checked in order; the first one reporting ``obs`` as a
        neighbor determines the reward. With no match, ``-0.0001`` is
        returned.
        """
        reached = next((g for g in self.goals if g.neighbor(obs)), None)
        if reached is None:
            return -0.0001
        return reached.reward_scale

    @staticmethod
    def create_maze() -> List[List[MazeCell]]:
        """Return the 7x7 cell grid of the maze (B: block, E: empty, R: robot)."""
        B = MazeCell.BLOCK
        E = MazeCell.EMPTY
        R = MazeCell.ROBOT
        return [
            [B, B, B, B, B, B, B],
            [B, E, E, E, E, E, B],
            [B, E, E, B, E, E, B],
            [B, E, E, B, E, E, B],
            [B, E, B, B, B, E, B],
            [B, E, E, R, E, E, B],
            [B, B, B, B, B, B, B],
        ]
class DistRewardTRoom(GoalRewardTRoom, DistRewardMixIn):
    """``GoalRewardTRoom`` combined with ``DistRewardMixIn``; adds no members.

    NOTE(review): ``GoalRewardTRoom`` comes first in the bases, so any
    attribute it defines (e.g. ``reward``, ``REWARD_THRESHOLD``) takes
    precedence over a same-named attribute on ``DistRewardMixIn`` -- confirm
    this resolution order is the intended one.
    """
class TaskRegistry: class TaskRegistry:
REGISTRY: Dict[str, List[Type[MazeTask]]] = { REGISTRY: Dict[str, List[Type[MazeTask]]] = {
"UMaze": [DistRewardUMaze, GoalRewardUMaze], "UMaze": [DistRewardUMaze, GoalRewardUMaze],
@ -244,6 +282,7 @@ class TaskRegistry:
"Fall": [DistRewardFall, GoalRewardFall], "Fall": [DistRewardFall, GoalRewardFall],
"2Rooms": [DistReward2Rooms, GoalReward2Rooms, SubGoal2Rooms], "2Rooms": [DistReward2Rooms, GoalReward2Rooms, SubGoal2Rooms],
"4Rooms": [DistReward4Rooms, GoalReward4Rooms, SubGoal4Rooms], "4Rooms": [DistReward4Rooms, GoalReward4Rooms, SubGoal4Rooms],
"TRoom": [DistRewardTRoom, GoalRewardTRoom],
} }
@staticmethod @staticmethod

View File

@ -20,3 +20,11 @@ def test_point_maze(maze_id):
assert env.reset().shape == (7,) assert env.reset().shape == (7,)
s, _, _, _ = env.step(env.action_space.sample()) s, _, _, _ = env.step(env.action_space.sample())
assert s.shape == (7,) assert s.shape == (7,)
@pytest.mark.parametrize("v", [0, 1])
def test_maze_args(v):
    """Check that ``task_kwargs`` passed to ``gym.make`` yields a usable env.

    Builds ``PointTRoom-v{v}`` with a custom goal list and verifies the
    observation shape after ``reset`` and after one sampled ``step``.
    """
    env_id = f"PointTRoom-v{v}"
    custom_task_kwargs = {"goals": [(-2.0, -4.0)]}
    env = gym.make(env_id, task_kwargs=custom_task_kwargs)

    initial_obs = env.reset()
    assert initial_obs.shape == (7,)

    sampled_action = env.action_space.sample()
    next_obs, _, _, _ = env.step(sampled_action)
    assert next_obs.shape == (7,)