diff --git a/mujoco_maze/alt_rendering.py b/mujoco_maze/alt_rendering.py new file mode 100644 index 0000000..4aa65b9 --- /dev/null +++ b/mujoco_maze/alt_rendering.py @@ -0,0 +1,92 @@ +""" +2D rendering framework +""" +import os +import sys + +if "Apple" in sys.version: + if "DYLD_FALLBACK_LIBRARY_PATH" in os.environ: + os.environ["DYLD_FALLBACK_LIBRARY_PATH"] += ":/usr/lib" + # (JDS 2016/04/15): avoid bug on Anaconda 2.3.0 / Yosemite + +from gym import error + + +import pyglet + +from pyglet import gl + + +RAD2DEG = 57.29577951308232 + + +def get_display(spec): + """Convert a display specification (such as :0) into an actual Display + object. + + Pyglet only supports multiple Displays on Linux. + """ + if spec is None: + return pyglet.canvas.get_display() + # returns already available pyglet_display, + # if there is no pyglet display available then it creates one + elif isinstance(spec, str): + return pyglet.canvas.Display(spec) + else: + raise error.Error( + f"Invalid display specification: {spec}. (Must be a string like :0 or None.)" + ) + + +def get_window(width, height, display, **kwargs): + """ + Will create a pyglet window from the display specification provided. + """ + screen = display.get_screens() # available screens + config = screen[0].get_best_config() # selecting the first screen + context = config.create_context(None) # create GL context + + return pyglet.window.Window( + width=width, + height=height, + display=display, + config=config, + context=context, + **kwargs, + ) + + +class ImageViewer: + def __init__(self, width, height, display=None): + display = get_display(display) + + self.width = width + self.height = height + self.window = get_window(width=width, height=height, display=display) + self.window.on_close = self.window_closed_by_user + self.isopen = True + + gl.glEnable(gl.GL_BLEND) + gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA) + + def imshow(self, arr): + assert len(arr.shape) == 3, "You passed in an image with the wrong number shape" + self.window.clear() + self.window.switch_to() + self.window.dispatch_events() + image = pyglet.image.ImageData( + arr.shape[1], arr.shape[0], "RGB", arr.tobytes(), pitch=arr.shape[1] * -3 + ) + image.blit(0, 0) # draw + self.window.flip() + + def close(self): + if self.isopen and sys.meta_path: + self.window.close() + self.isopen = False + + def window_closed_by_user(self): + self.isopen = False + + def __del__(self): + self.close() diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py index 7619c35..51cc7d7 100644 --- a/mujoco_maze/maze_env.py +++ b/mujoco_maze/maze_env.py @@ -10,7 +10,7 @@ import itertools as it import os import tempfile import xml.etree.ElementTree as ET -from typing import List, Tuple, Type +from typing import Any, List, Tuple, Type import gym import numpy as np @@ -33,7 +33,7 @@ class MazeEnv(gym.Env): inner_reward_scaling: float = 1.0, restitution_coef: float = 0.8, task_kwargs: dict = {}, - *args, + use_alt_viewer: bool = False, **kwargs, ) -> None: self._task = maze_task(maze_size_scaling, **task_kwargs) @@ -70,11 +70,19 @@ class MazeEnv(gym.Env): if model_cls.RADIUS is None: raise ValueError("Manual collision needs radius of the model") self._collision = maze_env_utils.CollisionDetector( - structure, size_scaling, torso_x, torso_y, model_cls.RADIUS, + structure, + size_scaling, + torso_x, + torso_y, + model_cls.RADIUS, ) # Now all object balls have size=1.0 self._objball_collision = maze_env_utils.CollisionDetector( - structure, size_scaling, torso_x, torso_y, self._task.OBJECT_BALL_SIZE, + structure, + size_scaling, + torso_x, + torso_y, + self._task.OBJECT_BALL_SIZE, ) else: self._collision = None @@ -141,7 +149,15 @@ class MazeEnv(gym.Env): # Movable block. self.movable_blocks.append(f"movable_{i}_{j}") _add_movable_block( - worldbody, struct, i, j, size_scaling, x, y, h, height_offset, + worldbody, + struct, + i, + j, + size_scaling, + x, + y, + h, + height_offset, ) elif struct.is_object_ball(): # Movable Ball @@ -173,8 +189,10 @@ class MazeEnv(gym.Env): _, file_path = tempfile.mkstemp(text=True, suffix=".xml") tree.write(file_path) self.world_tree = tree - self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs) + self.wrapped_env = model_cls(file_path=file_path, **kwargs) self.observation_space = self._get_obs_space() + self._use_alt_viewer = use_alt_viewer + self._alt_viewer = None @property def has_extended_obs(self) -> bool: @@ -337,11 +355,26 @@ class MazeEnv(gym.Env): self.data.site_xpos[idx][: len(goal.pos)] = goal.pos @property - def viewer(self): - return self.wrapped_env.viewer + def viewer(self) -> Any: + if self._use_alt_viewer: + return self._alt_viewer + else: + return self.wrapped_env.viewer - def render(self, *args, **kwargs): - return self.wrapped_env.render(*args, **kwargs) + def render(self, mode="human", **kwargs) -> Any: + if self._use_alt_viewer: + image = self.wrapped_env.sim.render(640, 480)[::-1, :, :] + if self._alt_viewer is None: + from mujoco_maze.alt_rendering import ImageViewer + + self._alt_viewer = ImageViewer(640, 480) + if mode == "rgb_array": + return image + else: + self._alt_viewer.imshow(image) + return self._alt_viewer.isopen + else: + return self.wrapped_env.render(mode=mode, **kwargs) @property def action_space(self): @@ -406,6 +439,8 @@ class MazeEnv(gym.Env): def close(self) -> None: self.wrapped_env.close() + if self._alt_viewer is not None: + self._alt_viewer.close() def _add_object_ball( @@ -479,7 +514,10 @@ def _add_movable_block( shrink = 1.0 size = size_scaling * 0.5 * shrink movable_body = ET.SubElement( - worldbody, "body", name=f"movable_{i}_{j}", pos=f"{x} {y} {h}", + worldbody, + "body", + name=f"movable_{i}_{j}", + pos=f"{x} {y} {h}", ) ET.SubElement( movable_body, diff --git a/mujoco_maze/point.py b/mujoco_maze/point.py index 4e8c071..6f64603 100644 --- a/mujoco_maze/point.py +++ b/mujoco_maze/point.py @@ -22,7 +22,7 @@ class PointEnv(AgentModel): VELOCITY_LIMITS: float = 10.0 - def __init__(self, file_path: Optional[str] = None): + def __init__(self, file_path: Optional[str] = None) -> None: super().__init__(file_path, 1) high = np.inf * np.ones(6, dtype=np.float32) high[3:] = self.VELOCITY_LIMITS * 1.2