Refactor using AgentModel

This commit is contained in:
kngwyu 2020-05-29 17:29:14 +09:00
parent 7287642a76
commit b77425efdb
5 changed files with 59 additions and 23 deletions

View File

@ -0,0 +1,39 @@
"""Common API definition for Ant and Point.
"""
from abc import ABC, abstractmethod
from gym.envs.mujoco.mujoco_env import MujocoEnv
from gym.utils import EzPickle
import numpy as np
from typing import Tuple
class AgentModel(ABC, MujocoEnv, EzPickle):
    """Abstract interface shared by the agent models (Ant and Point).

    Concrete agents set the class attributes below and implement every
    abstract method declared here.
    """

    # File name of the agent's MuJoCo XML model; the maze environment joins
    # it onto its model directory to locate the file.
    FILE: str
    # Index into ``sim.data.qpos`` where the orientation entries begin
    # (AntEnv reads a 4-element quaternion starting at this index).
    ORI_IND: int

    def __init__(self, file_path: str, frame_skip: int) -> None:
        """Initialize the MuJoCo environment and the pickling helper.

        Args:
            file_path: Path of the model XML file, forwarded to ``MujocoEnv``.
            frame_skip: Frame-skip value, forwarded to ``MujocoEnv``.
        """
        # Both bases are initialized by name, in this order, rather than
        # through a cooperative ``super()`` chain.
        MujocoEnv.__init__(self, file_path, frame_skip)
        EzPickle.__init__(self)

    @abstractmethod
    def _get_obs(self) -> np.ndarray:
        """Return the observation from the model."""

    @abstractmethod
    def get_xy(self) -> Tuple[float, float]:
        """Return the (x, y) coordinate of the agent."""

    @abstractmethod
    def set_xy(self, xy: Tuple[float, float]) -> None:
        """Set the (x, y) coordinate of the agent.

        Args:
            xy: The new coordinate for the agent.
        """

    @abstractmethod
    def get_ori(self) -> float:
        """Return the orientation of the agent as a single float."""

View File

@ -17,8 +17,8 @@
import math
import numpy as np
from gym import utils
from gym.envs.mujoco import mujoco_env
from mujoco_maze.agent_model import AgentModel
def q_inv(a):
@ -33,7 +33,7 @@ def q_mult(a, b): # multiply two quaternion
return [w, i, j, k]
class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
class AntEnv(AgentModel):
FILE = "ant.xml"
ORI_IND = 3
@ -50,8 +50,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._body_com_indices = {}
self._body_comvel_indices = {}
mujoco_env.MujocoEnv.__init__(self, file_path, 5)
utils.EzPickle.__init__(self)
super().__init__(file_path, 5)
def _step(self, a):
return self.step(a)
@ -126,9 +125,8 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def get_ori(self):
ori = [0, 1, 0, 0]
rot = self.sim.data.qpos[
self.__class__.ORI_IND : self.__class__.ORI_IND + 4
] # take the quaternion
ori_ind = self.ORI_IND
rot = self.sim.data.qpos[ori_ind: ori_ind + 4] # take the quaternion
ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3] # project onto x-y plane
ori = math.atan2(ori[1], ori[0])
return ori

View File

@ -22,6 +22,9 @@ import math
import numpy as np
import gym
from typing import Type
from mujoco_maze.agent_model import AgentModel
from mujoco_maze import maze_env_utils
# Directory that contains mujoco xml files.
@ -29,7 +32,7 @@ MODEL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/assets"
class MazeEnv(gym.Env):
MODEL_CLASS = None
MODEL_CLASS: Type[AgentModel] = AgentModel
MAZE_HEIGHT = None
MAZE_SIZE_SCALING = None
@ -51,10 +54,7 @@ class MazeEnv(gym.Env):
):
self._maze_id = maze_id
model_cls = self.__class__.MODEL_CLASS
if model_cls is None:
raise "MODEL_CLASS unspecified!"
xml_path = os.path.join(MODEL_DIR, model_cls.FILE)
xml_path = os.path.join(MODEL_DIR, self.MODEL_CLASS.FILE)
tree = ET.parse(xml_path)
worldbody = tree.find(".//worldbody")
@ -264,7 +264,7 @@ class MazeEnv(gym.Env):
_, file_path = tempfile.mkstemp(text=True, suffix=".xml")
tree.write(file_path)
self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs)
self.wrapped_env = self.MODEL_CLASS(*args, file_path=file_path, **kwargs)
def get_ori(self):
return self.wrapped_env.get_ori()
@ -477,7 +477,7 @@ class MazeEnv(gym.Env):
self.t = 0
self.wrapped_env.reset()
if len(self._init_positions) > 1:
xy = random.choice(self._init_positions)
xy = np.random.choice(self._init_positions)
self.wrapped_env.set_xy(xy)
return self._get_obs()

View File

@ -17,19 +17,18 @@
import math
import numpy as np
from gym import utils
from gym.envs.mujoco import mujoco_env
from mujoco_maze.agent_model import AgentModel
class PointEnv(mujoco_env.MujocoEnv, utils.EzPickle):
class PointEnv(AgentModel):
FILE = "point.xml"
ORI_IND = 2
def __init__(self, file_path=None, expose_all_qpos=True):
self._expose_all_qpos = expose_all_qpos
mujoco_env.MujocoEnv.__init__(self, file_path, 1)
utils.EzPickle.__init__(self)
super().__init__(file_path, 1)
def _step(self, a):
return self.step(a)

View File

@ -10,9 +10,9 @@ repository = "https://github.com/kngwyu/mujoco-maze"
homepage = "https://github.com/kngwyu/mujoco-maze"
[tool.poetry.dependencies]
python = "^3.5" # Compatible python versions must be declared here
gym = ">=0.14"
mujoco-py = ">=2.0"
python = "^3.6" # Compatible python versions must be declared here
gym = ">=0.16"
mujoco-py = ">=1.5"
[tool.poetry.dev-dependencies]
pytest = "^3.0"