Refactor using AgentModel
This commit is contained in:
parent
7287642a76
commit
b77425efdb
39
mujoco_maze/agent_model.py
Normal file
39
mujoco_maze/agent_model.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""Common API definition for Ant and Point.
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from gym.envs.mujoco.mujoco_env import MujocoEnv
|
||||
from gym.utils import EzPickle
|
||||
import numpy as np
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
class AgentModel(ABC, MujocoEnv, EzPickle):
|
||||
FILE: str
|
||||
ORI_IND: int
|
||||
|
||||
def __init__(self, file_path: str, frame_skip: int) -> None:
|
||||
MujocoEnv.__init__(self, file_path, frame_skip)
|
||||
EzPickle.__init__(self)
|
||||
|
||||
@abstractmethod
|
||||
def _get_obs(self) -> np.ndarray:
|
||||
"""Returns the observation from the model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_xy(self) -> Tuple[float, float]:
|
||||
"""Returns the coordinate of the agent.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_xy(self, xy: Tuple[float, float]) -> None:
|
||||
"""Set the coordinate of the agent.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_ori(self) -> float:
|
||||
pass
|
||||
|
@ -17,8 +17,8 @@
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from gym import utils
|
||||
from gym.envs.mujoco import mujoco_env
|
||||
|
||||
from mujoco_maze.agent_model import AgentModel
|
||||
|
||||
|
||||
def q_inv(a):
|
||||
@ -33,7 +33,7 @@ def q_mult(a, b): # multiply two quaternion
|
||||
return [w, i, j, k]
|
||||
|
||||
|
||||
class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
class AntEnv(AgentModel):
|
||||
FILE = "ant.xml"
|
||||
ORI_IND = 3
|
||||
|
||||
@ -50,8 +50,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
self._body_com_indices = {}
|
||||
self._body_comvel_indices = {}
|
||||
|
||||
mujoco_env.MujocoEnv.__init__(self, file_path, 5)
|
||||
utils.EzPickle.__init__(self)
|
||||
super().__init__(file_path, 5)
|
||||
|
||||
def _step(self, a):
|
||||
return self.step(a)
|
||||
@ -126,9 +125,8 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
|
||||
def get_ori(self):
|
||||
ori = [0, 1, 0, 0]
|
||||
rot = self.sim.data.qpos[
|
||||
self.__class__.ORI_IND : self.__class__.ORI_IND + 4
|
||||
] # take the quaternion
|
||||
ori_ind = self.ORI_IND
|
||||
rot = self.sim.data.qpos[ori_ind: ori_ind + 4] # take the quaternion
|
||||
ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3] # project onto x-y plane
|
||||
ori = math.atan2(ori[1], ori[0])
|
||||
return ori
|
||||
|
@ -22,6 +22,9 @@ import math
|
||||
import numpy as np
|
||||
import gym
|
||||
|
||||
from typing import Type
|
||||
|
||||
from mujoco_maze.agent_model import AgentModel
|
||||
from mujoco_maze import maze_env_utils
|
||||
|
||||
# Directory that contains mujoco xml files.
|
||||
@ -29,7 +32,7 @@ MODEL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/assets"
|
||||
|
||||
|
||||
class MazeEnv(gym.Env):
|
||||
MODEL_CLASS = None
|
||||
MODEL_CLASS: Type[AgentModel] = AgentModel
|
||||
|
||||
MAZE_HEIGHT = None
|
||||
MAZE_SIZE_SCALING = None
|
||||
@ -51,10 +54,7 @@ class MazeEnv(gym.Env):
|
||||
):
|
||||
self._maze_id = maze_id
|
||||
|
||||
model_cls = self.__class__.MODEL_CLASS
|
||||
if model_cls is None:
|
||||
raise "MODEL_CLASS unspecified!"
|
||||
xml_path = os.path.join(MODEL_DIR, model_cls.FILE)
|
||||
xml_path = os.path.join(MODEL_DIR, self.MODEL_CLASS.FILE)
|
||||
tree = ET.parse(xml_path)
|
||||
worldbody = tree.find(".//worldbody")
|
||||
|
||||
@ -264,7 +264,7 @@ class MazeEnv(gym.Env):
|
||||
_, file_path = tempfile.mkstemp(text=True, suffix=".xml")
|
||||
tree.write(file_path)
|
||||
|
||||
self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs)
|
||||
self.wrapped_env = self.MODEL_CLASS(*args, file_path=file_path, **kwargs)
|
||||
|
||||
def get_ori(self):
|
||||
return self.wrapped_env.get_ori()
|
||||
@ -477,7 +477,7 @@ class MazeEnv(gym.Env):
|
||||
self.t = 0
|
||||
self.wrapped_env.reset()
|
||||
if len(self._init_positions) > 1:
|
||||
xy = random.choice(self._init_positions)
|
||||
xy = np.random.choice(self._init_positions)
|
||||
self.wrapped_env.set_xy(xy)
|
||||
return self._get_obs()
|
||||
|
||||
|
@ -17,19 +17,18 @@
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from gym import utils
|
||||
from gym.envs.mujoco import mujoco_env
|
||||
|
||||
from mujoco_maze.agent_model import AgentModel
|
||||
|
||||
|
||||
class PointEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
class PointEnv(AgentModel):
|
||||
FILE = "point.xml"
|
||||
ORI_IND = 2
|
||||
|
||||
def __init__(self, file_path=None, expose_all_qpos=True):
|
||||
self._expose_all_qpos = expose_all_qpos
|
||||
|
||||
mujoco_env.MujocoEnv.__init__(self, file_path, 1)
|
||||
utils.EzPickle.__init__(self)
|
||||
super().__init__(file_path, 1)
|
||||
|
||||
def _step(self, a):
|
||||
return self.step(a)
|
||||
|
@ -10,9 +10,9 @@ repository = "https://github.com/kngwyu/mujoco-maze"
|
||||
homepage = "https://github.com/kngwyu/mujoco-maze"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.5" # Compatible python versions must be declared here
|
||||
gym = ">=0.14"
|
||||
mujoco-py = ">=2.0"
|
||||
python = "^3.6" # Compatible python versions must be declared here
|
||||
gym = ">=0.16"
|
||||
mujoco-py = ">=1.5"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = "^3.0"
|
||||
|
Loading…
Reference in New Issue
Block a user