Websocket viewer

This commit is contained in:
kngwyu 2021-04-13 18:18:39 +09:00
parent e12e38e64e
commit 9e16a68dc2

View File

@ -10,7 +10,7 @@ import itertools as it
import os import os
import tempfile import tempfile
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from typing import Any, List, Tuple, Type from typing import Any, List, Optional, Tuple, Type
import gym import gym
import numpy as np import numpy as np
@ -33,7 +33,7 @@ class MazeEnv(gym.Env):
inner_reward_scaling: float = 1.0, inner_reward_scaling: float = 1.0,
restitution_coef: float = 0.8, restitution_coef: float = 0.8,
task_kwargs: dict = {}, task_kwargs: dict = {},
use_alt_viewer: bool = False, websock_render_port: Optional[int] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
self._task = maze_task(maze_size_scaling, **task_kwargs) self._task = maze_task(maze_size_scaling, **task_kwargs)
@ -191,8 +191,9 @@ class MazeEnv(gym.Env):
self.world_tree = tree self.world_tree = tree
self.wrapped_env = model_cls(file_path=file_path, **kwargs) self.wrapped_env = model_cls(file_path=file_path, **kwargs)
self.observation_space = self._get_obs_space() self.observation_space = self._get_obs_space()
self._use_alt_viewer = use_alt_viewer self._websock_port = websock_port
self._alt_viewer = None self._mj_offscreen_viewer = None
self._websock_server_pipe = None
@property @property
def has_extended_obs(self) -> bool: def has_extended_obs(self) -> bool:
@ -356,36 +357,75 @@ class MazeEnv(gym.Env):
@property @property
def viewer(self) -> Any: def viewer(self) -> Any:
if self._use_alt_viewer: if self._websock_port is not None:
return self._alt_viewer return self._mj_viewer
else: else:
return self.wrapped_env.viewer return self.wrapped_env.viewer
def render(self, mode="human", **kwargs) -> Any: def _setup_websocket_server(self) -> None:
if self._use_alt_viewer: import multiprocessing as mp
if self._alt_viewer is None:
from gym.envs.classic_control.rendering import SimpleImageViewer
import mujoco_py
self._mj_viewer = mujoco_py.MjRenderContextOffscreen( self._mj_offscreen_viewer = mujoco_py.MjRenderContextOffscreen(
self.wrapped_env.sim self.wrapped_env.sim
) )
self._alt_viewer = SimpleImageViewer()
self._mj_viewer._set_mujoco_buffers() class _ProcWorker(mp.Process):
self._mj_viewer.render(640, 480) def __init__(self, pipe: mp.connection.Pip) -> None:
image = np.asarray( super().__init__()
self._mj_viewer.read_pixels(640, 480, depth=False)[::-1, :, :], self.pipe = pipe
def _run_server(self) -> None:
import asyncio
import io
import websockets
async def handler(ws, _path):
loop = asyncio.get_running_loop()
while True:
image = await loop.run_in_executor(None, self.pipe.recv)
if image is None:
print("Shutting down the websocket server...")
return
with io.BytesIO() as stream:
image.save(stream, format="png")
await ws.send(stream.getvalue())
server = websockets.serve(handler, "127.0.0.1", 5678)
asyncio.get_event_loop().run_until_complete(server)
asyncio.get_event_loop().run_forever()
def run(self) -> None:
try:
self._run_server()
except KeyboardInterrupt:
pass
except Exception as e:
print("Exception in websocket server")
raise e
self._websock_server_pipe, pipe = mp.Pipe()
worker = _ProcWorker(pipe)
worker.start()
def _render_image(self):
from PIL import Image
self._mj_offscreen_viewer._set_mujoco_buffers()
self._mj_offscreen_viewer.render(640, 480)
image_array = np.asarray(
self._mj_offscreen_viewer.read_pixels(640, 480, depth=False)[::-1, :, :],
dtype=np.uint8, dtype=np.uint8,
) )
if mode == "rgb_array": return Image.fromarray(image_array)
return image
def render(self, *args, **kwargs) -> Any:
if self._websock_port is not None:
if self._mj_offscreen_viewer is None:
self._setup_websocket_viewer()
self._websock_server_pipe.send(self._render_image())
return None
else: else:
if not (image.min() == 0 and image.max() == 0): return self.wrapped_env.render(*args, **kwargs)
self._alt_viewer.imshow(image)
return self._alt_viewer.isopen
else:
return self.wrapped_env.render(mode=mode, **kwargs)
@property @property
def action_space(self): def action_space(self):
@ -450,8 +490,8 @@ class MazeEnv(gym.Env):
def close(self) -> None: def close(self) -> None:
self.wrapped_env.close() self.wrapped_env.close()
if self._alt_viewer is not None: if self._websock_server_pipe is not None:
self._alt_viewer.close() self._websock_server_pipe.send(self._render_image())
def _add_object_ball( def _add_object_ball(