From 74319808385020c1101344cd75e33e9739fbccd4 Mon Sep 17 00:00:00 2001 From: kngwyu Date: Sun, 31 May 2020 01:35:00 +0900 Subject: [PATCH] Modify collision detection --- mujoco_maze/__init__.py | 14 +++++ mujoco_maze/agent_model.py | 5 +- mujoco_maze/assets/point.xml | 23 ++++---- mujoco_maze/maze_env.py | 101 ++++++++++++++++------------------ mujoco_maze/maze_env_utils.py | 12 ++-- mujoco_maze/point.py | 3 +- 6 files changed, 82 insertions(+), 76 deletions(-) diff --git a/mujoco_maze/__init__.py b/mujoco_maze/__init__.py index 9779c0b..5f00d6e 100644 --- a/mujoco_maze/__init__.py +++ b/mujoco_maze/__init__.py @@ -19,6 +19,13 @@ for maze_id in MAZE_IDS: max_episode_steps=1000, reward_threshold=-1000, ) + gym.envs.register( + id="Ant{}-v1".format(maze_id), + entry_point="mujoco_maze.ant_maze_env:AntMazeEnv", + kwargs=dict(maze_size_scaling=8.0, **_get_kwargs(maze_id)), + max_episode_steps=1000, + reward_threshold=0.9, + ) for maze_id in MAZE_IDS: gym.envs.register( @@ -28,6 +35,13 @@ for maze_id in MAZE_IDS: max_episode_steps=1000, reward_threshold=-1000, ) + gym.envs.register( + id="Point{}-v1".format(maze_id), + entry_point="mujoco_maze.point_maze_env:PointMazeEnv", + kwargs=dict(**_get_kwargs(maze_id), dense_reward=False), + max_episode_steps=1000, + reward_threshold=0.9 + ) __version__ = "0.1.0" diff --git a/mujoco_maze/agent_model.py b/mujoco_maze/agent_model.py index 627c1b5..63fbd3d 100644 --- a/mujoco_maze/agent_model.py +++ b/mujoco_maze/agent_model.py @@ -4,7 +4,6 @@ from abc import ABC, abstractmethod from gym.envs.mujoco.mujoco_env import MujocoEnv from gym.utils import EzPickle import numpy as np -from typing import Tuple class AgentModel(ABC, MujocoEnv, EzPickle): @@ -22,13 +21,13 @@ class AgentModel(ABC, MujocoEnv, EzPickle): pass @abstractmethod - def get_xy(self) -> Tuple[float, float]: + def get_xy(self) -> np.ndarray: """Returns the coordinate of the agent. """ pass @abstractmethod - def set_xy(self, xy: Tuple[float, float]) -> None: + def set_xy(self, xy: np.ndarray) -> None: """Set the coordinate of the agent. """ pass diff --git a/mujoco_maze/assets/point.xml b/mujoco_maze/assets/point.xml index c382e16..fe2d2e5 100755 --- a/mujoco_maze/assets/point.xml +++ b/mujoco_maze/assets/point.xml @@ -3,29 +3,30 @@