Refactor using AgentModel

This commit is contained in:
kngwyu 2020-05-29 17:29:14 +09:00
parent 7287642a76
commit b77425efdb
5 changed files with 59 additions and 23 deletions

View File

@ -0,0 +1,39 @@
"""Common API definition for Ant and Point.
"""
from abc import ABC, abstractmethod
from gym.envs.mujoco.mujoco_env import MujocoEnv
from gym.utils import EzPickle
import numpy as np
from typing import Tuple
class AgentModel(ABC, MujocoEnv, EzPickle):
FILE: str
ORI_IND: int
def __init__(self, file_path: str, frame_skip: int) -> None:
MujocoEnv.__init__(self, file_path, frame_skip)
EzPickle.__init__(self)
@abstractmethod
def _get_obs(self) -> np.ndarray:
"""Returns the observation from the model.
"""
pass
@abstractmethod
def get_xy(self) -> Tuple[float, float]:
"""Returns the coordinate of the agent.
"""
pass
@abstractmethod
def set_xy(self, xy: Tuple[float, float]) -> None:
"""Set the coordinate of the agent.
"""
pass
@abstractmethod
def get_ori(self) -> float:
pass

View File

@ -17,8 +17,8 @@
import math import math
import numpy as np import numpy as np
from gym import utils
from gym.envs.mujoco import mujoco_env from mujoco_maze.agent_model import AgentModel
def q_inv(a): def q_inv(a):
@ -33,7 +33,7 @@ def q_mult(a, b): # multiply two quaternion
return [w, i, j, k] return [w, i, j, k]
class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): class AntEnv(AgentModel):
FILE = "ant.xml" FILE = "ant.xml"
ORI_IND = 3 ORI_IND = 3
@ -50,8 +50,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._body_com_indices = {} self._body_com_indices = {}
self._body_comvel_indices = {} self._body_comvel_indices = {}
mujoco_env.MujocoEnv.__init__(self, file_path, 5) super().__init__(file_path, 5)
utils.EzPickle.__init__(self)
def _step(self, a): def _step(self, a):
return self.step(a) return self.step(a)
@ -126,9 +125,8 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def get_ori(self): def get_ori(self):
ori = [0, 1, 0, 0] ori = [0, 1, 0, 0]
rot = self.sim.data.qpos[ ori_ind = self.ORI_IND
self.__class__.ORI_IND : self.__class__.ORI_IND + 4 rot = self.sim.data.qpos[ori_ind: ori_ind + 4] # take the quaternion
] # take the quaternion
ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3] # project onto x-y plane ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3] # project onto x-y plane
ori = math.atan2(ori[1], ori[0]) ori = math.atan2(ori[1], ori[0])
return ori return ori

View File

@ -22,6 +22,9 @@ import math
import numpy as np import numpy as np
import gym import gym
from typing import Type
from mujoco_maze.agent_model import AgentModel
from mujoco_maze import maze_env_utils from mujoco_maze import maze_env_utils
# Directory that contains mujoco xml files. # Directory that contains mujoco xml files.
@ -29,7 +32,7 @@ MODEL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/assets"
class MazeEnv(gym.Env): class MazeEnv(gym.Env):
MODEL_CLASS = None MODEL_CLASS: Type[AgentModel] = AgentModel
MAZE_HEIGHT = None MAZE_HEIGHT = None
MAZE_SIZE_SCALING = None MAZE_SIZE_SCALING = None
@ -51,10 +54,7 @@ class MazeEnv(gym.Env):
): ):
self._maze_id = maze_id self._maze_id = maze_id
model_cls = self.__class__.MODEL_CLASS xml_path = os.path.join(MODEL_DIR, self.MODEL_CLASS.FILE)
if model_cls is None:
raise "MODEL_CLASS unspecified!"
xml_path = os.path.join(MODEL_DIR, model_cls.FILE)
tree = ET.parse(xml_path) tree = ET.parse(xml_path)
worldbody = tree.find(".//worldbody") worldbody = tree.find(".//worldbody")
@ -264,7 +264,7 @@ class MazeEnv(gym.Env):
_, file_path = tempfile.mkstemp(text=True, suffix=".xml") _, file_path = tempfile.mkstemp(text=True, suffix=".xml")
tree.write(file_path) tree.write(file_path)
self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs) self.wrapped_env = self.MODEL_CLASS(*args, file_path=file_path, **kwargs)
def get_ori(self): def get_ori(self):
return self.wrapped_env.get_ori() return self.wrapped_env.get_ori()
@ -477,7 +477,7 @@ class MazeEnv(gym.Env):
self.t = 0 self.t = 0
self.wrapped_env.reset() self.wrapped_env.reset()
if len(self._init_positions) > 1: if len(self._init_positions) > 1:
xy = random.choice(self._init_positions) xy = np.random.choice(self._init_positions)
self.wrapped_env.set_xy(xy) self.wrapped_env.set_xy(xy)
return self._get_obs() return self._get_obs()

View File

@ -17,19 +17,18 @@
import math import math
import numpy as np import numpy as np
from gym import utils
from gym.envs.mujoco import mujoco_env from mujoco_maze.agent_model import AgentModel
class PointEnv(mujoco_env.MujocoEnv, utils.EzPickle): class PointEnv(AgentModel):
FILE = "point.xml" FILE = "point.xml"
ORI_IND = 2 ORI_IND = 2
def __init__(self, file_path=None, expose_all_qpos=True): def __init__(self, file_path=None, expose_all_qpos=True):
self._expose_all_qpos = expose_all_qpos self._expose_all_qpos = expose_all_qpos
mujoco_env.MujocoEnv.__init__(self, file_path, 1) super().__init__(file_path, 1)
utils.EzPickle.__init__(self)
def _step(self, a): def _step(self, a):
return self.step(a) return self.step(a)

View File

@ -10,9 +10,9 @@ repository = "https://github.com/kngwyu/mujoco-maze"
homepage = "https://github.com/kngwyu/mujoco-maze" homepage = "https://github.com/kngwyu/mujoco-maze"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.5" # Compatible python versions must be declared here python = "^3.6" # Compatible python versions must be declared here
gym = ">=0.14" gym = ">=0.16"
mujoco-py = ">=2.0" mujoco-py = ">=1.5"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
pytest = "^3.0" pytest = "^3.0"