Refactor using AgentModel
This commit is contained in:
parent
7287642a76
commit
b77425efdb
39
mujoco_maze/agent_model.py
Normal file
39
mujoco_maze/agent_model.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
"""Common API definition for Ant and Point.
|
||||||
|
"""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from gym.envs.mujoco.mujoco_env import MujocoEnv
|
||||||
|
from gym.utils import EzPickle
|
||||||
|
import numpy as np
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class AgentModel(ABC, MujocoEnv, EzPickle):
    """Abstract interface shared by the maze agents (Ant and Point).

    Concrete agents assign the class attributes below and implement every
    abstract accessor so that a maze environment can drive any agent through
    the same set of methods.
    """

    # File name of the agent's MuJoCo XML model; assigned by each subclass.
    FILE: str
    # Start index of the orientation entry inside ``qpos``; assigned by each
    # subclass.
    ORI_IND: int

    def __init__(self, file_path: str, frame_skip: int) -> None:
        """Initialize both base classes explicitly.

        Args:
            file_path: Path of the model XML handed to ``MujocoEnv``.
            frame_skip: Frame-skip value handed to ``MujocoEnv``.
        """
        MujocoEnv.__init__(self, file_path, frame_skip)
        # Called without arguments, exactly as the concrete agents did before
        # this base class existed.
        EzPickle.__init__(self)

    @abstractmethod
    def _get_obs(self) -> np.ndarray:
        """Return the observation produced by the model."""

    @abstractmethod
    def get_xy(self) -> Tuple[float, float]:
        """Return the (x, y) coordinate of the agent."""

    @abstractmethod
    def set_xy(self, xy: Tuple[float, float]) -> None:
        """Move the agent to the given (x, y) coordinate."""

    @abstractmethod
    def get_ori(self) -> float:
        """Return the orientation of the agent as a single float."""
|
||||||
|
|
@ -17,8 +17,8 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gym import utils
|
|
||||||
from gym.envs.mujoco import mujoco_env
|
from mujoco_maze.agent_model import AgentModel
|
||||||
|
|
||||||
|
|
||||||
def q_inv(a):
|
def q_inv(a):
|
||||||
@ -33,7 +33,7 @@ def q_mult(a, b): # multiply two quaternion
|
|||||||
return [w, i, j, k]
|
return [w, i, j, k]
|
||||||
|
|
||||||
|
|
||||||
class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
class AntEnv(AgentModel):
|
||||||
FILE = "ant.xml"
|
FILE = "ant.xml"
|
||||||
ORI_IND = 3
|
ORI_IND = 3
|
||||||
|
|
||||||
@ -50,8 +50,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
self._body_com_indices = {}
|
self._body_com_indices = {}
|
||||||
self._body_comvel_indices = {}
|
self._body_comvel_indices = {}
|
||||||
|
|
||||||
mujoco_env.MujocoEnv.__init__(self, file_path, 5)
|
super().__init__(file_path, 5)
|
||||||
utils.EzPickle.__init__(self)
|
|
||||||
|
|
||||||
def _step(self, a):
|
def _step(self, a):
|
||||||
return self.step(a)
|
return self.step(a)
|
||||||
@ -126,9 +125,8 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
def get_ori(self):
    """Return the agent's heading angle in the x-y plane.

    The quaternion read from ``qpos`` (four entries starting at ``ORI_IND``)
    is used to rotate a fixed reference vector via ``q_mult`` / ``q_inv``;
    the rotated vector's x and y components are converted to an angle with
    ``math.atan2``.
    """
    start = self.ORI_IND
    quat = self.sim.data.qpos[start:start + 4]  # take the quaternion
    reference = [0, 1, 0, 0]
    rotated = q_mult(q_mult(quat, reference), q_inv(quat))
    x, y = rotated[1:3]  # project onto x-y plane
    return math.atan2(y, x)
|
||||||
|
@ -22,6 +22,9 @@ import math
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import gym
|
import gym
|
||||||
|
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from mujoco_maze.agent_model import AgentModel
|
||||||
from mujoco_maze import maze_env_utils
|
from mujoco_maze import maze_env_utils
|
||||||
|
|
||||||
# Directory that contains mujoco xml files.
|
# Directory that contains mujoco xml files.
|
||||||
@ -29,7 +32,7 @@ MODEL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/assets"
|
|||||||
|
|
||||||
|
|
||||||
class MazeEnv(gym.Env):
|
class MazeEnv(gym.Env):
|
||||||
MODEL_CLASS = None
|
MODEL_CLASS: Type[AgentModel] = AgentModel
|
||||||
|
|
||||||
MAZE_HEIGHT = None
|
MAZE_HEIGHT = None
|
||||||
MAZE_SIZE_SCALING = None
|
MAZE_SIZE_SCALING = None
|
||||||
@ -51,10 +54,7 @@ class MazeEnv(gym.Env):
|
|||||||
):
|
):
|
||||||
self._maze_id = maze_id
|
self._maze_id = maze_id
|
||||||
|
|
||||||
model_cls = self.__class__.MODEL_CLASS
|
xml_path = os.path.join(MODEL_DIR, self.MODEL_CLASS.FILE)
|
||||||
if model_cls is None:
|
|
||||||
raise "MODEL_CLASS unspecified!"
|
|
||||||
xml_path = os.path.join(MODEL_DIR, model_cls.FILE)
|
|
||||||
tree = ET.parse(xml_path)
|
tree = ET.parse(xml_path)
|
||||||
worldbody = tree.find(".//worldbody")
|
worldbody = tree.find(".//worldbody")
|
||||||
|
|
||||||
@ -264,7 +264,7 @@ class MazeEnv(gym.Env):
|
|||||||
_, file_path = tempfile.mkstemp(text=True, suffix=".xml")
|
_, file_path = tempfile.mkstemp(text=True, suffix=".xml")
|
||||||
tree.write(file_path)
|
tree.write(file_path)
|
||||||
|
|
||||||
self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs)
|
self.wrapped_env = self.MODEL_CLASS(*args, file_path=file_path, **kwargs)
|
||||||
|
|
||||||
def get_ori(self):
|
def get_ori(self):
|
||||||
return self.wrapped_env.get_ori()
|
return self.wrapped_env.get_ori()
|
||||||
@ -477,7 +477,7 @@ class MazeEnv(gym.Env):
|
|||||||
self.t = 0
|
self.t = 0
|
||||||
self.wrapped_env.reset()
|
self.wrapped_env.reset()
|
||||||
if len(self._init_positions) > 1:
|
if len(self._init_positions) > 1:
|
||||||
xy = random.choice(self._init_positions)
|
xy = np.random.choice(self._init_positions)
|
||||||
self.wrapped_env.set_xy(xy)
|
self.wrapped_env.set_xy(xy)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
|
@ -17,19 +17,18 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gym import utils
|
|
||||||
from gym.envs.mujoco import mujoco_env
|
from mujoco_maze.agent_model import AgentModel
|
||||||
|
|
||||||
|
|
||||||
class PointEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
class PointEnv(AgentModel):
|
||||||
FILE = "point.xml"
|
FILE = "point.xml"
|
||||||
ORI_IND = 2
|
ORI_IND = 2
|
||||||
|
|
||||||
def __init__(self, file_path=None, expose_all_qpos=True):
    """Record the observation flag, then initialize the agent base class.

    Args:
        file_path: Path of the model XML forwarded to the base initializer.
        expose_all_qpos: Stored on the instance as ``_expose_all_qpos``.
    """
    # Assigned before the base initializer runs, preserving the original
    # statement order.
    self._expose_all_qpos = expose_all_qpos

    frame_skip = 1
    super().__init__(file_path, frame_skip)
|
|
||||||
|
|
||||||
def _step(self, a):
|
def _step(self, a):
|
||||||
return self.step(a)
|
return self.step(a)
|
||||||
|
@ -10,9 +10,9 @@ repository = "https://github.com/kngwyu/mujoco-maze"
|
|||||||
homepage = "https://github.com/kngwyu/mujoco-maze"
|
homepage = "https://github.com/kngwyu/mujoco-maze"
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.5" # Compatible python versions must be declared here
|
python = "^3.6" # Compatible python versions must be declared here
|
||||||
gym = ">=0.14"
|
gym = ">=0.16"
|
||||||
mujoco-py = ">=2.0"
|
mujoco-py = ">=1.5"
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
pytest = "^3.0"
|
pytest = "^3.0"
|
||||||
|
Loading…
Reference in New Issue
Block a user