Websocket viewer

This commit is contained in:
kngwyu 2021-04-13 18:18:39 +09:00
parent e12e38e64e
commit 9e16a68dc2

View File

@ -10,7 +10,7 @@ import itertools as it
import os
import tempfile
import xml.etree.ElementTree as ET
from typing import Any, List, Tuple, Type
from typing import Any, List, Optional, Tuple, Type
import gym
import numpy as np
@ -33,7 +33,7 @@ class MazeEnv(gym.Env):
inner_reward_scaling: float = 1.0,
restitution_coef: float = 0.8,
task_kwargs: dict = {},
use_alt_viewer: bool = False,
websock_render_port: Optional[int] = None,
**kwargs,
) -> None:
self._task = maze_task(maze_size_scaling, **task_kwargs)
@ -191,8 +191,9 @@ class MazeEnv(gym.Env):
self.world_tree = tree
self.wrapped_env = model_cls(file_path=file_path, **kwargs)
self.observation_space = self._get_obs_space()
self._use_alt_viewer = use_alt_viewer
self._alt_viewer = None
self._websock_port = websock_port
self._mj_offscreen_viewer = None
self._websock_server_pipe = None
@property
def has_extended_obs(self) -> bool:
@ -356,36 +357,75 @@ class MazeEnv(gym.Env):
@property
def viewer(self) -> Any:
if self._use_alt_viewer:
return self._alt_viewer
if self._websock_port is not None:
return self._mj_viewer
else:
return self.wrapped_env.viewer
def render(self, mode="human", **kwargs) -> Any:
if self._use_alt_viewer:
if self._alt_viewer is None:
from gym.envs.classic_control.rendering import SimpleImageViewer
import mujoco_py
def _setup_websocket_server(self) -> None:
import multiprocessing as mp
self._mj_viewer = mujoco_py.MjRenderContextOffscreen(
self.wrapped_env.sim
)
self._alt_viewer = SimpleImageViewer()
self._mj_offscreen_viewer = mujoco_py.MjRenderContextOffscreen(
self.wrapped_env.sim
)
self._mj_viewer._set_mujoco_buffers()
self._mj_viewer.render(640, 480)
image = np.asarray(
self._mj_viewer.read_pixels(640, 480, depth=False)[::-1, :, :],
dtype=np.uint8,
)
if mode == "rgb_array":
return image
else:
if not (image.min() == 0 and image.max() == 0):
self._alt_viewer.imshow(image)
return self._alt_viewer.isopen
class _ProcWorker(mp.Process):
def __init__(self, pipe: mp.connection.Pip) -> None:
super().__init__()
self.pipe = pipe
def _run_server(self) -> None:
import asyncio
import io
import websockets
async def handler(ws, _path):
loop = asyncio.get_running_loop()
while True:
image = await loop.run_in_executor(None, self.pipe.recv)
if image is None:
print("Shutting down the websocket server...")
return
with io.BytesIO() as stream:
image.save(stream, format="png")
await ws.send(stream.getvalue())
server = websockets.serve(handler, "127.0.0.1", 5678)
asyncio.get_event_loop().run_until_complete(server)
asyncio.get_event_loop().run_forever()
def run(self) -> None:
try:
self._run_server()
except KeyboardInterrupt:
pass
except Exception as e:
print("Exception in websocket server")
raise e
self._websock_server_pipe, pipe = mp.Pipe()
worker = _ProcWorker(pipe)
worker.start()
def _render_image(self):
from PIL import Image
self._mj_offscreen_viewer._set_mujoco_buffers()
self._mj_offscreen_viewer.render(640, 480)
image_array = np.asarray(
self._mj_offscreen_viewer.read_pixels(640, 480, depth=False)[::-1, :, :],
dtype=np.uint8,
)
return Image.fromarray(image_array)
def render(self, *args, **kwargs) -> Any:
if self._websock_port is not None:
if self._mj_offscreen_viewer is None:
self._setup_websocket_viewer()
self._websock_server_pipe.send(self._render_image())
return None
else:
return self.wrapped_env.render(mode=mode, **kwargs)
return self.wrapped_env.render(*args, **kwargs)
@property
def action_space(self):
@ -450,8 +490,8 @@ class MazeEnv(gym.Env):
def close(self) -> None:
self.wrapped_env.close()
if self._alt_viewer is not None:
self._alt_viewer.close()
if self._websock_server_pipe is not None:
self._websock_server_pipe.send(self._render_image())
def _add_object_ball(