From 8f630b58c20931bc289b1702770a436ef3946e55 Mon Sep 17 00:00:00 2001 From: kngwyu Date: Thu, 17 Sep 2020 01:27:38 +0900 Subject: [PATCH 1/2] Add TRoom --- mujoco_maze/maze_env.py | 3 ++- mujoco_maze/maze_task.py | 41 +++++++++++++++++++++++++++++++++++++++- tests/test_envs.py | 8 ++++++++ 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py index 52df046..d9fe7e8 100644 --- a/mujoco_maze/maze_env.py +++ b/mujoco_maze/maze_env.py @@ -32,10 +32,11 @@ class MazeEnv(gym.Env): maze_size_scaling: float = 4.0, inner_reward_scaling: float = 1.0, restitution_coef: float = 0.8, + task_kwargs: dict = {}, *args, **kwargs, ) -> None: - self._task = maze_task(maze_size_scaling) + self._task = maze_task(maze_size_scaling, **task_kwargs) xml_path = os.path.join(MODEL_DIR, model_cls.FILE) tree = ET.parse(xml_path) diff --git a/mujoco_maze/maze_task.py b/mujoco_maze/maze_task.py index 74ed313..a3feda3 100644 --- a/mujoco_maze/maze_task.py +++ b/mujoco_maze/maze_task.py @@ -2,7 +2,7 @@ """ from abc import ABC, abstractmethod -from typing import Dict, List, NamedTuple, Type +from typing import Dict, List, NamedTuple, Tuple, Type import numpy as np @@ -237,6 +237,44 @@ class SubGoal4Rooms(GoalReward4Rooms): ] +class GoalRewardTRoom(MazeTask): + REWARD_THRESHOLD: float = 0.9 + MAZE_SIZE_SCALING: Scaling = Scaling(4.0, 4.0) + + def __init__( + self, + scale: float, + goals: List[Tuple[float, float]] = [(2.0, -4.0)], + ) -> None: + super().__init__(scale) + self.goals = [] + for x, y in goals: + self.goals.append(MazeGoal(np.array([x * scale, y * scale]))) + + def reward(self, obs: np.ndarray) -> float: + for goal in self.goals: + if goal.neighbor(obs): + return goal.reward_scale + return -0.0001 + + @staticmethod + def create_maze() -> List[List[MazeCell]]: + E, B, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.ROBOT + return [ + [B, B, B, B, B, B, B], + [B, E, E, E, E, E, B], + [B, E, E, B, E, E, B], + [B, E, E, B, E, E, B], + [B, E, B, B, B, E, B], + [B, E, E, R, E, E, B], + [B, B, B, B, B, B, B], + ] + + +class DistRewardTRoom(GoalRewardTRoom, DistRewardMixIn): + pass + + class TaskRegistry: REGISTRY: Dict[str, List[Type[MazeTask]]] = { "UMaze": [DistRewardUMaze, GoalRewardUMaze], @@ -244,6 +282,7 @@ class TaskRegistry: "Fall": [DistRewardFall, GoalRewardFall], "2Rooms": [DistReward2Rooms, GoalReward2Rooms, SubGoal2Rooms], "4Rooms": [DistReward4Rooms, GoalReward4Rooms, SubGoal4Rooms], + "TRoom": [DistRewardTRoom, GoalRewardTRoom], } @staticmethod diff --git a/tests/test_envs.py b/tests/test_envs.py index 469ab1a..45346c4 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -20,3 +20,11 @@ def test_point_maze(maze_id): assert env.reset().shape == (7,) s, _, _, _ = env.step(env.action_space.sample()) assert s.shape == (7,) + + +@pytest.mark.parametrize("v", [0, 1]) +def test_maze_args(v): + env = gym.make(f"PointTRoom-v{v}", task_kwargs={"goals": [(-2.0, 4.0)]}) + assert env.reset().shape == (7,) + s, _, _, _ = env.step(env.action_space.sample()) + assert s.shape == (7,) From fe06ba7c5b4a666a45e5d50533d00ce81ddcbb29 Mon Sep 17 00:00:00 2001 From: kngwyu Date: Thu, 17 Sep 2020 01:57:00 +0900 Subject: [PATCH 2/2] Implement close --- mujoco_maze/agent_model.py | 7 +++++++ mujoco_maze/maze_env.py | 3 +++ tests/test_envs.py | 2 +- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/mujoco_maze/agent_model.py b/mujoco_maze/agent_model.py index 57c7de5..cefaa49 100644 --- a/mujoco_maze/agent_model.py +++ b/mujoco_maze/agent_model.py @@ -17,6 +17,13 @@ class AgentModel(ABC, MujocoEnv, EzPickle): MujocoEnv.__init__(self, file_path, frame_skip) EzPickle.__init__(self) + def close(self): + if self.viewer is not None and hasattr(self.viewer, "window"): + import glfw + + glfw.destroy_window(self.viewer.window) + super().close() + @abstractmethod def _get_obs(self) -> np.ndarray: """Returns the observation from the model. diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py index d9fe7e8..0e8cbc0 100644 --- a/mujoco_maze/maze_env.py +++ b/mujoco_maze/maze_env.py @@ -453,3 +453,6 @@ class MazeEnv(gym.Env): done = self._task.termination(next_obs) info["position"] = self.wrapped_env.get_xy() return next_obs, inner_reward + outer_reward, done, info + + def close(self): + self.wrapped_env.close() diff --git a/tests/test_envs.py b/tests/test_envs.py index 45346c4..812aa29 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -24,7 +24,7 @@ def test_point_maze(maze_id): @pytest.mark.parametrize("v", [0, 1]) def test_maze_args(v): - env = gym.make(f"PointTRoom-v{v}", task_kwargs={"goals": [(-2.0, 4.0)]}) + env = gym.make(f"PointTRoom-v{v}", task_kwargs={"goals": [(-2.0, -4.0)]}) assert env.reset().shape == (7,) s, _, _, _ = env.step(env.action_space.sample()) assert s.shape == (7,)