Web-based viewer

This commit is contained in:
kngwyu 2021-04-13 19:18:50 +09:00
parent 9e16a68dc2
commit 15a1a1384f
3 changed files with 73 additions and 28 deletions

1
MANIFEST.in Normal file
View File

@ -0,0 +1 @@
include mujoco-maze/static/index.html

View File

@ -33,7 +33,7 @@ class MazeEnv(gym.Env):
inner_reward_scaling: float = 1.0,
restitution_coef: float = 0.8,
task_kwargs: dict = {},
websock_render_port: Optional[int] = None,
websock_port: Optional[int] = None,
**kwargs,
) -> None:
self._task = maze_task(maze_size_scaling, **task_kwargs)
@ -362,37 +362,52 @@ class MazeEnv(gym.Env):
else:
return self.wrapped_env.viewer
def _setup_websocket_server(self) -> None:
def _setup_websock_server(self) -> None:
import multiprocessing as mp
self._mj_offscreen_viewer = mujoco_py.MjRenderContextOffscreen(
self.wrapped_env.sim
)
class _ProcWorker(mp.Process):
def __init__(self, pipe: mp.connection.Pip) -> None:
def __init__(self, pipe: mp.connection.Pipe, port: int) -> None:
super().__init__()
self.pipe = pipe
self.port = port
self.server = None
def _run_server(self) -> None:
import asyncio
import io
import websockets
import fastapi
import pathlib
import uvicorn
async def handler(ws, _path):
from PIL import Image
app = fastapi.FastAPI()
html_path = pathlib.Path(__file__).parent.joinpath("static/index.html")
html = html_path.read_text().replace("{{port}}", str(self.port))
@app.get("/")
async def get():
return fastapi.responses.HTMLResponse(html)
@app.websocket("/ws")
async def ws_send_image(websocket: fastapi.WebSocket):
await websocket.accept()
loop = asyncio.get_running_loop()
while True:
image = await loop.run_in_executor(None, self.pipe.recv)
if image is None:
print("Shutting down the websocket server...")
return
image_array = await loop.run_in_executor(None, self.pipe.recv)
if image_array is None:
break
image = Image.fromarray(image_array)
with io.BytesIO() as stream:
image.save(stream, format="png")
await ws.send(stream.getvalue())
res = stream.getvalue()
await websocket.send_bytes(res)
await websocket.close()
await self.server.shutdown()
server = websockets.serve(handler, "127.0.0.1", 5678)
asyncio.get_event_loop().run_until_complete(server)
asyncio.get_event_loop().run_forever()
config = uvicorn.Config(app, port=self.port)
self.server = uvicorn.Server(config)
self.server.run()
def run(self) -> None:
try:
@ -404,28 +419,30 @@ class MazeEnv(gym.Env):
raise e
self._websock_server_pipe, pipe = mp.Pipe()
worker = _ProcWorker(pipe)
worker = _ProcWorker(pipe, self._websock_port)
worker.start()
def _render_image(self):
from PIL import Image
def _render_image(self) -> np.ndarray:
self._mj_offscreen_viewer._set_mujoco_buffers()
self._mj_offscreen_viewer.render(640, 480)
image_array = np.asarray(
return np.asarray(
self._mj_offscreen_viewer.read_pixels(640, 480, depth=False)[::-1, :, :],
dtype=np.uint8,
)
return Image.fromarray(image_array)
def render(self, *args, **kwargs) -> Any:
def render(self, mode="human", **kwargs) -> Any:
if self._websock_port is not None:
if self._mj_offscreen_viewer is None:
self._setup_websocket_viewer()
import mujoco_py
self._mj_offscreen_viewer = mujoco_py.MjRenderContextOffscreen(
self.wrapped_env.sim
)
self._setup_websock_server()
self._websock_server_pipe.send(self._render_image())
return None
return True
else:
return self.wrapped_env.render(*args, **kwargs)
return self.wrapped_env.render(mode, **kwargs)
@property
def action_space(self):
@ -491,7 +508,7 @@ class MazeEnv(gym.Env):
def close(self) -> None:
self.wrapped_env.close()
if self._websock_server_pipe is not None:
self._websock_server_pipe.send(self._render_image())
self._websock_server_pipe.send(None)
def _add_object_ball(

View File

@ -0,0 +1,27 @@
<!DOCTYPE html>
<html>
<head>
<title>MuJoCo maze visualizer</title>
</head>
<body>
<script>
var web_socket = new WebSocket('ws://127.0.0.1:{{port}}/ws');
web_socket.binaryType = "arraybuffer";
web_socket.onmessage = function(event) {
var canvas = document.getElementById('canvas');
var ctx = canvas.getContext('2d');
var blob = new Blob([event.data], {type:'image/png'});
var url = URL.createObjectURL(blob);
var image = new Image();
image.onload = function() {
ctx.drawImage(image, 0, 0);
}
console.log(url);
image.src = url;
}
</script>
<div>
<canvas id="canvas" width="600" height="480"></canvas>
</div>
</body>
</html>