From b77425efdbc2fd4f53eea2dd2a9e4e9ea623724f Mon Sep 17 00:00:00 2001 From: kngwyu Date: Fri, 29 May 2020 17:29:14 +0900 Subject: [PATCH] Refactor using AgentModel --- mujoco_maze/agent_model.py | 39 ++++++++++++++++++++++++++++++++++++++ mujoco_maze/ant.py | 14 ++++++-------- mujoco_maze/maze_env.py | 14 +++++++------- mujoco_maze/point.py | 9 ++++----- pyproject.toml | 6 +++--- 5 files changed, 59 insertions(+), 23 deletions(-) create mode 100644 mujoco_maze/agent_model.py diff --git a/mujoco_maze/agent_model.py b/mujoco_maze/agent_model.py new file mode 100644 index 0000000..d2b7049 --- /dev/null +++ b/mujoco_maze/agent_model.py @@ -0,0 +1,39 @@ +"""Common API definition for Ant and Point. +""" +from abc import ABC, abstractmethod +from gym.envs.mujoco.mujoco_env import MujocoEnv +from gym.utils import EzPickle +import numpy as np +from typing import Tuple + + +class AgentModel(ABC, MujocoEnv, EzPickle): + FILE: str + ORI_IND: int + + def __init__(self, file_path: str, frame_skip: int) -> None: + MujocoEnv.__init__(self, file_path, frame_skip) + EzPickle.__init__(self) + + @abstractmethod + def _get_obs(self) -> np.ndarray: + """Returns the observation from the model. + """ + pass + + @abstractmethod + def get_xy(self) -> Tuple[float, float]: + """Returns the coordinate of the agent. + """ + pass + + @abstractmethod + def set_xy(self, xy: Tuple[float, float]) -> None: + """Set the coordinate of the agent. + """ + pass + + @abstractmethod + def get_ori(self) -> float: + pass + diff --git a/mujoco_maze/ant.py b/mujoco_maze/ant.py index 48a9bc2..cad281a 100644 --- a/mujoco_maze/ant.py +++ b/mujoco_maze/ant.py @@ -17,8 +17,8 @@ import math import numpy as np -from gym import utils -from gym.envs.mujoco import mujoco_env + +from mujoco_maze.agent_model import AgentModel def q_inv(a): @@ -33,7 +33,7 @@ def q_mult(a, b): # multiply two quaternion return [w, i, j, k] -class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): +class AntEnv(AgentModel): FILE = "ant.xml" ORI_IND = 3 @@ -50,8 +50,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): self._body_com_indices = {} self._body_comvel_indices = {} - mujoco_env.MujocoEnv.__init__(self, file_path, 5) - utils.EzPickle.__init__(self) + super().__init__(file_path, 5) def _step(self, a): return self.step(a) @@ -126,9 +125,8 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): def get_ori(self): ori = [0, 1, 0, 0] - rot = self.sim.data.qpos[ - self.__class__.ORI_IND : self.__class__.ORI_IND + 4 - ] # take the quaternion + ori_ind = self.ORI_IND + rot = self.sim.data.qpos[ori_ind: ori_ind + 4] # take the quaternion ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3] # project onto x-y plane ori = math.atan2(ori[1], ori[0]) return ori diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py index ef19476..369ccec 100644 --- a/mujoco_maze/maze_env.py +++ b/mujoco_maze/maze_env.py @@ -22,6 +22,9 @@ import math import numpy as np import gym +from typing import Type + +from mujoco_maze.agent_model import AgentModel from mujoco_maze import maze_env_utils # Directory that contains mujoco xml files. @@ -29,7 +32,7 @@ MODEL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/assets" class MazeEnv(gym.Env): - MODEL_CLASS = None + MODEL_CLASS: Type[AgentModel] = AgentModel MAZE_HEIGHT = None MAZE_SIZE_SCALING = None @@ -51,10 +54,7 @@ class MazeEnv(gym.Env): ): self._maze_id = maze_id - model_cls = self.__class__.MODEL_CLASS - if model_cls is None: - raise "MODEL_CLASS unspecified!" - xml_path = os.path.join(MODEL_DIR, model_cls.FILE) + xml_path = os.path.join(MODEL_DIR, self.MODEL_CLASS.FILE) tree = ET.parse(xml_path) worldbody = tree.find(".//worldbody") @@ -264,7 +264,7 @@ class MazeEnv(gym.Env): _, file_path = tempfile.mkstemp(text=True, suffix=".xml") tree.write(file_path) - self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs) + self.wrapped_env = self.MODEL_CLASS(*args, file_path=file_path, **kwargs) def get_ori(self): return self.wrapped_env.get_ori() @@ -477,7 +477,7 @@ class MazeEnv(gym.Env): self.t = 0 self.wrapped_env.reset() if len(self._init_positions) > 1: - xy = random.choice(self._init_positions) + xy = np.random.choice(self._init_positions) self.wrapped_env.set_xy(xy) return self._get_obs() diff --git a/mujoco_maze/point.py b/mujoco_maze/point.py index d209820..09b441d 100644 --- a/mujoco_maze/point.py +++ b/mujoco_maze/point.py @@ -17,19 +17,18 @@ import math import numpy as np -from gym import utils -from gym.envs.mujoco import mujoco_env + +from mujoco_maze.agent_model import AgentModel -class PointEnv(mujoco_env.MujocoEnv, utils.EzPickle): +class PointEnv(AgentModel): FILE = "point.xml" ORI_IND = 2 def __init__(self, file_path=None, expose_all_qpos=True): self._expose_all_qpos = expose_all_qpos - mujoco_env.MujocoEnv.__init__(self, file_path, 1) - utils.EzPickle.__init__(self) + super().__init__(file_path, 1) def _step(self, a): return self.step(a) diff --git a/pyproject.toml b/pyproject.toml index 224f40d..28f0097 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,9 +10,9 @@ repository = "https://github.com/kngwyu/mujoco-maze" homepage = "https://github.com/kngwyu/mujoco-maze" [tool.poetry.dependencies] -python = "^3.5" # Compatible python versions must be declared here -gym = ">=0.14" -mujoco-py = ">=2.0" +python = "^3.6" # Compatible python versions must be declared here +gym = ">=0.16" +mujoco-py = ">=1.5" [tool.poetry.dev-dependencies] pytest = "^3.0"