Merge pull request #13 from DigitalRev0lution/gym-v0.26.0
adjust for gym 0.26.0
commit fde62d9903
@@ -6,6 +6,7 @@ from typing import Optional
 import numpy as np
 from gym.envs.mujoco.mujoco_env import MujocoEnv
 from gym.utils import EzPickle
+from gym.spaces import Space
 
 
 class AgentModel(ABC, MujocoEnv, EzPickle):
@@ -15,8 +16,8 @@ class AgentModel(ABC, MujocoEnv, EzPickle):
     RADIUS: Optional[float] = None
     OBJBALL_TYPE: Optional[str] = None
 
-    def __init__(self, file_path: str, frame_skip: int) -> None:
-        MujocoEnv.__init__(self, file_path, frame_skip)
+    def __init__(self, file_path: str, frame_skip: int, observation_space: Space) -> None:
+        MujocoEnv.__init__(self, file_path, frame_skip, observation_space)
         EzPickle.__init__(self)
 
     def close(self):
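The hunks above update the `AgentModel` base class: gym 0.26's `MujocoEnv` requires the observation space as a constructor argument instead of letting subclasses assign `self.observation_space` after the fact. A minimal sketch of the new pattern, assuming the gym 0.26 `MujocoEnv` signature used in this diff (`MyAgent` and `my_agent.xml` are hypothetical names for illustration):

```python
import numpy as np
from gym import spaces
from gym.envs.mujoco.mujoco_env import MujocoEnv


class MyAgent(MujocoEnv):
    def __init__(self) -> None:
        # Build the space first; gym 0.26's MujocoEnv.__init__ requires it.
        high = np.inf * np.ones(6, dtype=np.float32)
        observation_space = spaces.Box(-high, high)
        MujocoEnv.__init__(self, "my_agent.xml", 1, observation_space)
```

The next hunks make the matching changes in the maze environment, `MazeEnv`.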
@@ -18,6 +18,8 @@ import numpy as np
 from mujoco_maze import maze_env_utils, maze_task
 from mujoco_maze.agent_model import AgentModel
 
+from gym.core import ObsType
+
 # Directory that contains mujoco xml files.
 MODEL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/assets"
 
@@ -366,7 +368,7 @@ class MazeEnv(gym.Env):
         obs = np.concatenate([wrapped_obs[:3]] + additional_obs + [wrapped_obs[3:]])
         return np.concatenate([obs, *view, np.array([self.t * 0.001])])
 
-    def reset(self) -> np.ndarray:
+    def reset(self, **kwargs) -> Tuple[ObsType, dict]:
         self.t = 0
         self.wrapped_env.reset()
         # Samples a new goal
@@ -376,7 +378,8 @@ class MazeEnv(gym.Env):
         if len(self._init_positions) > 1:
             xy = np.random.choice(self._init_positions)
             self.wrapped_env.set_xy(xy)
-        return self._get_obs()
+        info = {}
+        return self._get_obs(), info
 
     def set_marker(self) -> None:
         for i, goal in enumerate(self._task.goals):
@@ -410,10 +413,11 @@ class MazeEnv(gym.Env):
                 self._websock_server_pipe = start_server(self._websock_port)
             return self._websock_server_pipe.send(self._render_image())
         else:
+            self.wrapped_env.render_mode = mode
             if self.wrapped_env.viewer is None:
-                self.wrapped_env.render(mode, **kwargs)
+                self.wrapped_env.render()
                 self._maybe_move_camera(self.wrapped_env.viewer)
-            return self.wrapped_env.render(mode, **kwargs)
+            return self.wrapped_env.render()
 
     @property
     def action_space(self):
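These `MazeEnv` hunks adopt the gym 0.26 `reset` contract, under which every environment returns an `(observation, info)` pair, and the new rendering flow, where the mode is stored on the environment as `render_mode` and `render()` takes no arguments. A hedged sketch of the contract being targeted (paraphrased from the gym 0.26 `Env` base class, not code from this repository):

```python
from typing import Optional, Tuple

import gym
import numpy as np


class SketchEnv(gym.Env):
    """Illustrates the gym 0.26 reset contract (illustrative only)."""

    def reset(self, *, seed: Optional[int] = None,
              options: Optional[dict] = None) -> Tuple[np.ndarray, dict]:
        super().reset(seed=seed)   # gym 0.26: base reset seeds self.np_random
        observation = np.zeros(1, dtype=np.float32)
        return observation, {}     # always return (obs, info)
```

Note that the diff accepts `**kwargs` rather than the explicit `seed`/`options` keywords and does not forward them to the wrapped environment, so callers passing the new keywords are tolerated but seeding is not plumbed through. The remaining hunks migrate the point agent, `PointEnv`.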
@@ -9,12 +9,22 @@ Based on `models`_ and `rllab`_.
 from typing import Optional, Tuple
 
 import gym
+import mujoco
 import numpy as np
 
 from mujoco_maze.agent_model import AgentModel
 
 
 class PointEnv(AgentModel):
+    metadata = {
+        "render_modes": [
+            "human",
+            "rgb_array",
+            "depth_array",
+        ],
+        "render_fps": 50,
+    }
+
     FILE: str = "point.xml"
     ORI_IND: int = 2
     MANUAL_COLLISION: bool = True
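The `metadata` block added above uses gym 0.26's key names. Older gym versions used dotted keys; a side-by-side sketch of the two conventions, with the values taken from this diff:

```python
# Pre-0.26 style (dotted keys):
old_metadata = {
    "render.modes": ["human", "rgb_array", "depth_array"],
    "video.frames_per_second": 50,
}

# gym 0.26 style, as added in this commit (snake_case, plural):
metadata = {
    "render_modes": ["human", "rgb_array", "depth_array"],
    "render_fps": 50,
}
```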
@@ -24,15 +34,15 @@ class PointEnv(AgentModel):
     VELOCITY_LIMITS: float = 10.0
 
     def __init__(self, file_path: Optional[str] = None) -> None:
-        super().__init__(file_path, 1)
         high = np.inf * np.ones(6, dtype=np.float32)
         high[3:] = self.VELOCITY_LIMITS * 1.2
         high[self.ORI_IND] = np.pi
         low = -high
-        self.observation_space = gym.spaces.Box(low, high)
+        observation_space = gym.spaces.Box(low, high)
+        super().__init__(file_path, 1, observation_space)
 
     def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
-        qpos = self.sim.data.qpos.copy()
+        qpos = self.data.qpos.copy()
         qpos[2] += action[1]
         # Clip orientation
         if qpos[2] < -np.pi:
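From here on, the diff replaces every `self.sim.data`/`self.sim.model` access: gym 0.26's `MujocoEnv` is built on the official `mujoco` bindings, which expose `MjModel` and `MjData` directly instead of nesting them under a mujoco-py `sim` object. A small self-contained sketch of the new access pattern (the inline XML model is illustrative only):

```python
import mujoco

XML = """
<mujoco>
  <worldbody>
    <body>
      <joint type="free"/>
      <geom size="0.1"/>
    </body>
  </worldbody>
</mujoco>
"""
model = mujoco.MjModel.from_xml_string(XML)
data = mujoco.MjData(model)
print(model.nq, model.nv)    # was: self.sim.model.nq, self.sim.model.nv
print(data.qpos, data.qvel)  # was: self.sim.data.qpos, self.sim.data.qvel
```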
@@ -43,26 +53,26 @@ class PointEnv(AgentModel):
         # Compute increment in each direction
         qpos[0] += np.cos(ori) * action[0]
         qpos[1] += np.sin(ori) * action[0]
-        qvel = np.clip(self.sim.data.qvel, -self.VELOCITY_LIMITS, self.VELOCITY_LIMITS)
+        qvel = np.clip(self.data.qvel, -self.VELOCITY_LIMITS, self.VELOCITY_LIMITS)
         self.set_state(qpos, qvel)
         for _ in range(0, self.frame_skip):
-            self.sim.step()
+            mujoco.mj_step(self.model, self.data)
         next_obs = self._get_obs()
         return next_obs, 0.0, False, {}
 
     def _get_obs(self):
         return np.concatenate(
             [
-                self.sim.data.qpos.flat[:3],  # Only point-relevant coords.
-                self.sim.data.qvel.flat[:3],
+                self.data.qpos.flat[:3],  # Only point-relevant coords.
+                self.data.qvel.flat[:3],
             ]
         )
 
     def reset_model(self):
         qpos = self.init_qpos + self.np_random.uniform(
-            size=self.sim.model.nq, low=-0.1, high=0.1
+            size=self.model.nq, low=-0.1, high=0.1
         )
-        qvel = self.init_qvel + self.np_random.randn(self.sim.model.nv) * 0.1
+        qvel = self.init_qvel + self.np_random.random(self.model.nv) * 0.1
 
         # Set everything other than point to original position and 0 velocity.
         qpos[3:] = self.init_qpos[3:]
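Two behavioural details in this hunk are worth flagging. First, manual sub-stepping now goes through the module-level `mujoco.mj_step(model, data)` instead of `self.sim.step()`. Second, gym 0.26 exposes `np_random` as a `numpy.random.Generator`, which has no `randn` method; the diff substitutes `random`, but `Generator.random` samples uniformly on [0, 1), whereas the old `RandomState.randn` sampled a standard normal, so `standard_normal` would be the behaviour-preserving replacement. A sketch of the Generator equivalents:

```python
import numpy as np

rng = np.random.default_rng(0)         # same type as gym 0.26's np_random
uniform_noise = rng.random(6)          # uniform on [0, 1) -- what the diff uses
normal_noise = rng.standard_normal(6)  # drop-in for the old randn behaviour
```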
@@ -71,12 +81,12 @@ class PointEnv(AgentModel):
         return self._get_obs()
 
     def get_xy(self):
-        return self.sim.data.qpos[:2].copy()
+        return self.data.qpos[:2].copy()
 
     def set_xy(self, xy: np.ndarray) -> None:
-        qpos = self.sim.data.qpos.copy()
+        qpos = self.data.qpos.copy()
         qpos[:2] = xy
-        self.set_state(qpos, self.sim.data.qvel)
+        self.set_state(qpos, self.data.qvel)
 
     def get_ori(self):
-        return self.sim.data.qpos[self.ORI_IND]
+        return self.data.qpos[self.ORI_IND]
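Taken together, the commit leaves the environments usable under gym 0.26 roughly as sketched below. Here `"PointUMaze-v0"` is assumed to be one of the ids this package registers; note that `reset` now returns a pair while `step` still returns the pre-0.26 four-tuple, which gym 0.26 tolerates with a deprecation warning:

```python
import gym
import mujoco_maze  # noqa: F401  # importing registers the maze env ids

env = gym.make("PointUMaze-v0")
obs, info = env.reset()                          # new-style reset: (obs, info)
for _ in range(10):
    action = env.action_space.sample()
    obs, reward, done, info = env.step(action)   # still the old 4-tuple
    if done:
        obs, info = env.reset()
```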