diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py index 7e7a68f..3a455d0 100644 --- a/mujoco_maze/maze_env.py +++ b/mujoco_maze/maze_env.py @@ -362,66 +362,6 @@ class MazeEnv(gym.Env): else: return self.wrapped_env.viewer - def _setup_websock_server(self) -> None: - import multiprocessing as mp - - class _ProcWorker(mp.Process): - def __init__(self, pipe: mp.connection.Pipe, port: int) -> None: - super().__init__() - self.pipe = pipe - self.port = port - - def _run_server(self) -> None: - import asyncio - import io - import fastapi - import pathlib - import uvicorn - - from PIL import Image - - app = fastapi.FastAPI() - html_path = pathlib.Path(__file__).parent.joinpath("static/index.html") - html = html_path.read_text().replace("{{port}}", str(self.port)) - server = None - - @app.get("/") - async def get(): - return fastapi.responses.HTMLResponse(html) - - @app.websocket("/ws") - async def ws_send_image(websocket: fastapi.WebSocket): - await websocket.accept() - loop = asyncio.get_running_loop() - while True: - image_array = await loop.run_in_executor(None, self.pipe.recv) - if image_array is None: - break - image = Image.fromarray(image_array) - with io.BytesIO() as stream: - image.save(stream, format="png") - res = stream.getvalue() - await websocket.send_bytes(res) - await websocket.close() - server.should_exit = True - - config = uvicorn.Config(app, port=self.port) - server = uvicorn.Server(config) - server.run() - - def run(self) -> None: - try: - self._run_server() - except KeyboardInterrupt: - pass - except Exception as e: - print("Exception in websocket server") - raise e - - self._websock_server_pipe, pipe = mp.Pipe() - worker = _ProcWorker(pipe, self._websock_port) - worker.start() - def _render_image(self) -> np.ndarray: self._mj_offscreen_viewer._set_mujoco_buffers() self._mj_offscreen_viewer.render(640, 480) @@ -433,12 +373,12 @@ class MazeEnv(gym.Env): def render(self, mode="human", **kwargs) -> Any: if self._websock_port is not None: if self._mj_offscreen_viewer is None: - import mujoco_py + from mujoco_py import MjRenderContextOffscreen as MjRenderOffscreen - self._mj_offscreen_viewer = mujoco_py.MjRenderContextOffscreen( - self.wrapped_env.sim - ) - self._setup_websock_server() + from mujoco_maze.websock_viewer import start_server + + self._mj_offscreen_viewer = MjRenderOffscreen(self.wrapped_env.sim) + self._websock_server_pipe = start_server(self._websock_port) self._websock_server_pipe.send(self._render_image()) return True else: diff --git a/mujoco_maze/static/favicon.ico b/mujoco_maze/static/favicon.ico deleted file mode 100644 index 45acf83..0000000 Binary files a/mujoco_maze/static/favicon.ico and /dev/null differ diff --git a/mujoco_maze/websock_viewer.py b/mujoco_maze/websock_viewer.py new file mode 100644 index 0000000..9dad232 --- /dev/null +++ b/mujoco_maze/websock_viewer.py @@ -0,0 +1,65 @@ +import asyncio +import io +import multiprocessing as mp +import pathlib + +import fastapi +import uvicorn + +from PIL import Image + + +class _ServerWorker(mp.Process): + def __init__(self, pipe: mp.connection.Pipe, port: int) -> None: + super().__init__() + self.pipe = pipe + self.port = port + + def _run_server(self) -> None: + + app = fastapi.FastAPI() + static = pathlib.Path(__file__).parent.joinpath("static") + html_path = static.joinpath("index.html") + html = html_path.read_text().replace("{{port}}", str(self.port)) + + @app.get("/") + async def get(): + return fastapi.responses.HTMLResponse(html) + + server = None + + @app.websocket("/ws") + async def ws_send_image(websocket: fastapi.WebSocket): + await websocket.accept() + loop = asyncio.get_running_loop() + while True: + image_array = await loop.run_in_executor(None, self.pipe.recv) + if image_array is None: + break + image = Image.fromarray(image_array) + with io.BytesIO() as stream: + image.save(stream, format="png") + res = stream.getvalue() + await websocket.send_bytes(res) + await websocket.close() + server.should_exit = True + + config = uvicorn.Config(app, port=self.port) + server = uvicorn.Server(config) + server.run() + + def run(self) -> None: + try: + self._run_server() + except KeyboardInterrupt: + pass + except Exception as e: + print("Exception in websocket server") + raise e + + +def start_server(port: int) -> mp.connection.Connection: + mainproc_pipe, server_pipe = mp.Pipe() + worker = _ServerWorker(server_pipe, port) + worker.start() + return mainproc_pipe