diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py index ac80885..9f6d2e1 100644 --- a/mujoco_maze/maze_env.py +++ b/mujoco_maze/maze_env.py @@ -10,7 +10,7 @@ import itertools as it import os import tempfile import xml.etree.ElementTree as ET -from typing import Any, List, Tuple, Type +from typing import Any, List, Optional, Tuple, Type import gym import numpy as np @@ -33,7 +33,7 @@ class MazeEnv(gym.Env): inner_reward_scaling: float = 1.0, restitution_coef: float = 0.8, task_kwargs: dict = {}, - use_alt_viewer: bool = False, + websock_render_port: Optional[int] = None, **kwargs, ) -> None: self._task = maze_task(maze_size_scaling, **task_kwargs) @@ -191,8 +191,9 @@ class MazeEnv(gym.Env): self.world_tree = tree self.wrapped_env = model_cls(file_path=file_path, **kwargs) self.observation_space = self._get_obs_space() - self._use_alt_viewer = use_alt_viewer - self._alt_viewer = None + self._websock_port = websock_port + self._mj_offscreen_viewer = None + self._websock_server_pipe = None @property def has_extended_obs(self) -> bool: @@ -356,36 +357,75 @@ class MazeEnv(gym.Env): @property def viewer(self) -> Any: - if self._use_alt_viewer: - return self._alt_viewer + if self._websock_port is not None: + return self._mj_viewer else: return self.wrapped_env.viewer - def render(self, mode="human", **kwargs) -> Any: - if self._use_alt_viewer: - if self._alt_viewer is None: - from gym.envs.classic_control.rendering import SimpleImageViewer - import mujoco_py + def _setup_websocket_server(self) -> None: + import multiprocessing as mp - self._mj_viewer = mujoco_py.MjRenderContextOffscreen( - self.wrapped_env.sim - ) - self._alt_viewer = SimpleImageViewer() + self._mj_offscreen_viewer = mujoco_py.MjRenderContextOffscreen( + self.wrapped_env.sim + ) - self._mj_viewer._set_mujoco_buffers() - self._mj_viewer.render(640, 480) - image = np.asarray( - self._mj_viewer.read_pixels(640, 480, depth=False)[::-1, :, :], - dtype=np.uint8, - ) - if mode == "rgb_array": - return image - else: - if not (image.min() == 0 and image.max() == 0): - self._alt_viewer.imshow(image) - return self._alt_viewer.isopen + class _ProcWorker(mp.Process): + def __init__(self, pipe: mp.connection.Pip) -> None: + super().__init__() + self.pipe = pipe + + def _run_server(self) -> None: + import asyncio + import io + import websockets + + async def handler(ws, _path): + loop = asyncio.get_running_loop() + while True: + image = await loop.run_in_executor(None, self.pipe.recv) + if image is None: + print("Shutting down the websocket server...") + return + with io.BytesIO() as stream: + image.save(stream, format="png") + await ws.send(stream.getvalue()) + + server = websockets.serve(handler, "127.0.0.1", 5678) + asyncio.get_event_loop().run_until_complete(server) + asyncio.get_event_loop().run_forever() + + def run(self) -> None: + try: + self._run_server() + except KeyboardInterrupt: + pass + except Exception as e: + print("Exception in websocket server") + raise e + + self._websock_server_pipe, pipe = mp.Pipe() + worker = _ProcWorker(pipe) + worker.start() + + def _render_image(self): + from PIL import Image + + self._mj_offscreen_viewer._set_mujoco_buffers() + self._mj_offscreen_viewer.render(640, 480) + image_array = np.asarray( + self._mj_offscreen_viewer.read_pixels(640, 480, depth=False)[::-1, :, :], + dtype=np.uint8, + ) + return Image.fromarray(image_array) + + def render(self, *args, **kwargs) -> Any: + if self._websock_port is not None: + if self._mj_offscreen_viewer is None: + self._setup_websocket_viewer() + self._websock_server_pipe.send(self._render_image()) + return None else: - return self.wrapped_env.render(mode=mode, **kwargs) + return self.wrapped_env.render(*args, **kwargs) @property def action_space(self): @@ -450,8 +490,8 @@ class MazeEnv(gym.Env): def close(self) -> None: self.wrapped_env.close() - if self._alt_viewer is not None: - self._alt_viewer.close() + if self._websock_server_pipe is not None: + self._websock_server_pipe.send(self._render_image()) def _add_object_ball(