Web-based viewer
This commit is contained in:
parent
9e16a68dc2
commit
15a1a1384f
1
MANIFEST.in
Normal file
1
MANIFEST.in
Normal file
@ -0,0 +1 @@
|
||||
include mujoco-maze/static/index.html
|
@ -33,7 +33,7 @@ class MazeEnv(gym.Env):
|
||||
inner_reward_scaling: float = 1.0,
|
||||
restitution_coef: float = 0.8,
|
||||
task_kwargs: dict = {},
|
||||
websock_render_port: Optional[int] = None,
|
||||
websock_port: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self._task = maze_task(maze_size_scaling, **task_kwargs)
|
||||
@ -362,37 +362,52 @@ class MazeEnv(gym.Env):
|
||||
else:
|
||||
return self.wrapped_env.viewer
|
||||
|
||||
def _setup_websocket_server(self) -> None:
|
||||
def _setup_websock_server(self) -> None:
|
||||
import multiprocessing as mp
|
||||
|
||||
self._mj_offscreen_viewer = mujoco_py.MjRenderContextOffscreen(
|
||||
self.wrapped_env.sim
|
||||
)
|
||||
|
||||
class _ProcWorker(mp.Process):
|
||||
def __init__(self, pipe: mp.connection.Pip) -> None:
|
||||
def __init__(self, pipe: mp.connection.Pipe, port: int) -> None:
|
||||
super().__init__()
|
||||
self.pipe = pipe
|
||||
self.port = port
|
||||
self.server = None
|
||||
|
||||
def _run_server(self) -> None:
|
||||
import asyncio
|
||||
import io
|
||||
import websockets
|
||||
import fastapi
|
||||
import pathlib
|
||||
import uvicorn
|
||||
|
||||
async def handler(ws, _path):
|
||||
from PIL import Image
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
html_path = pathlib.Path(__file__).parent.joinpath("static/index.html")
|
||||
html = html_path.read_text().replace("{{port}}", str(self.port))
|
||||
|
||||
@app.get("/")
|
||||
async def get():
|
||||
return fastapi.responses.HTMLResponse(html)
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def ws_send_image(websocket: fastapi.WebSocket):
|
||||
await websocket.accept()
|
||||
loop = asyncio.get_running_loop()
|
||||
while True:
|
||||
image = await loop.run_in_executor(None, self.pipe.recv)
|
||||
if image is None:
|
||||
print("Shutting down the websocket server...")
|
||||
return
|
||||
image_array = await loop.run_in_executor(None, self.pipe.recv)
|
||||
if image_array is None:
|
||||
break
|
||||
image = Image.fromarray(image_array)
|
||||
with io.BytesIO() as stream:
|
||||
image.save(stream, format="png")
|
||||
await ws.send(stream.getvalue())
|
||||
res = stream.getvalue()
|
||||
await websocket.send_bytes(res)
|
||||
await websocket.close()
|
||||
await self.server.shutdown()
|
||||
|
||||
server = websockets.serve(handler, "127.0.0.1", 5678)
|
||||
asyncio.get_event_loop().run_until_complete(server)
|
||||
asyncio.get_event_loop().run_forever()
|
||||
config = uvicorn.Config(app, port=self.port)
|
||||
self.server = uvicorn.Server(config)
|
||||
self.server.run()
|
||||
|
||||
def run(self) -> None:
|
||||
try:
|
||||
@ -404,28 +419,30 @@ class MazeEnv(gym.Env):
|
||||
raise e
|
||||
|
||||
self._websock_server_pipe, pipe = mp.Pipe()
|
||||
worker = _ProcWorker(pipe)
|
||||
worker = _ProcWorker(pipe, self._websock_port)
|
||||
worker.start()
|
||||
|
||||
def _render_image(self):
|
||||
from PIL import Image
|
||||
|
||||
def _render_image(self) -> np.ndarray:
|
||||
self._mj_offscreen_viewer._set_mujoco_buffers()
|
||||
self._mj_offscreen_viewer.render(640, 480)
|
||||
image_array = np.asarray(
|
||||
return np.asarray(
|
||||
self._mj_offscreen_viewer.read_pixels(640, 480, depth=False)[::-1, :, :],
|
||||
dtype=np.uint8,
|
||||
)
|
||||
return Image.fromarray(image_array)
|
||||
|
||||
def render(self, *args, **kwargs) -> Any:
|
||||
def render(self, mode="human", **kwargs) -> Any:
|
||||
if self._websock_port is not None:
|
||||
if self._mj_offscreen_viewer is None:
|
||||
self._setup_websocket_viewer()
|
||||
import mujoco_py
|
||||
|
||||
self._mj_offscreen_viewer = mujoco_py.MjRenderContextOffscreen(
|
||||
self.wrapped_env.sim
|
||||
)
|
||||
self._setup_websock_server()
|
||||
self._websock_server_pipe.send(self._render_image())
|
||||
return None
|
||||
return True
|
||||
else:
|
||||
return self.wrapped_env.render(*args, **kwargs)
|
||||
return self.wrapped_env.render(mode, **kwargs)
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
@ -491,7 +508,7 @@ class MazeEnv(gym.Env):
|
||||
def close(self) -> None:
|
||||
self.wrapped_env.close()
|
||||
if self._websock_server_pipe is not None:
|
||||
self._websock_server_pipe.send(self._render_image())
|
||||
self._websock_server_pipe.send(None)
|
||||
|
||||
|
||||
def _add_object_ball(
|
||||
|
27
mujoco_maze/static/index.html
Normal file
27
mujoco_maze/static/index.html
Normal file
@ -0,0 +1,27 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>MuJoCo maze visualizer</title>
|
||||
</head>
|
||||
<body>
|
||||
<script>
|
||||
var web_socket = new WebSocket('ws://127.0.0.1:{{port}}/ws');
|
||||
web_socket.binaryType = "arraybuffer";
|
||||
web_socket.onmessage = function(event) {
|
||||
var canvas = document.getElementById('canvas');
|
||||
var ctx = canvas.getContext('2d');
|
||||
var blob = new Blob([event.data], {type:'image/png'});
|
||||
var url = URL.createObjectURL(blob);
|
||||
var image = new Image();
|
||||
image.onload = function() {
|
||||
ctx.drawImage(image, 0, 0);
|
||||
}
|
||||
console.log(url);
|
||||
image.src = url;
|
||||
}
|
||||
</script>
|
||||
<div>
|
||||
<canvas id="canvas" width="600" height="480"></canvas>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
Loading…
Reference in New Issue
Block a user