Websocket viewer
This commit is contained in:
parent
e12e38e64e
commit
9e16a68dc2
@ -10,7 +10,7 @@ import itertools as it
|
||||
import os
|
||||
import tempfile
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Any, List, Tuple, Type
|
||||
from typing import Any, List, Optional, Tuple, Type
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
@ -33,7 +33,7 @@ class MazeEnv(gym.Env):
|
||||
inner_reward_scaling: float = 1.0,
|
||||
restitution_coef: float = 0.8,
|
||||
task_kwargs: dict = {},
|
||||
use_alt_viewer: bool = False,
|
||||
websock_render_port: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self._task = maze_task(maze_size_scaling, **task_kwargs)
|
||||
@ -191,8 +191,9 @@ class MazeEnv(gym.Env):
|
||||
self.world_tree = tree
|
||||
self.wrapped_env = model_cls(file_path=file_path, **kwargs)
|
||||
self.observation_space = self._get_obs_space()
|
||||
self._use_alt_viewer = use_alt_viewer
|
||||
self._alt_viewer = None
|
||||
self._websock_port = websock_port
|
||||
self._mj_offscreen_viewer = None
|
||||
self._websock_server_pipe = None
|
||||
|
||||
@property
|
||||
def has_extended_obs(self) -> bool:
|
||||
@ -356,36 +357,75 @@ class MazeEnv(gym.Env):
|
||||
|
||||
@property
|
||||
def viewer(self) -> Any:
|
||||
if self._use_alt_viewer:
|
||||
return self._alt_viewer
|
||||
if self._websock_port is not None:
|
||||
return self._mj_viewer
|
||||
else:
|
||||
return self.wrapped_env.viewer
|
||||
|
||||
def render(self, mode="human", **kwargs) -> Any:
|
||||
if self._use_alt_viewer:
|
||||
if self._alt_viewer is None:
|
||||
from gym.envs.classic_control.rendering import SimpleImageViewer
|
||||
import mujoco_py
|
||||
def _setup_websocket_server(self) -> None:
|
||||
import multiprocessing as mp
|
||||
|
||||
self._mj_viewer = mujoco_py.MjRenderContextOffscreen(
|
||||
self.wrapped_env.sim
|
||||
)
|
||||
self._alt_viewer = SimpleImageViewer()
|
||||
self._mj_offscreen_viewer = mujoco_py.MjRenderContextOffscreen(
|
||||
self.wrapped_env.sim
|
||||
)
|
||||
|
||||
self._mj_viewer._set_mujoco_buffers()
|
||||
self._mj_viewer.render(640, 480)
|
||||
image = np.asarray(
|
||||
self._mj_viewer.read_pixels(640, 480, depth=False)[::-1, :, :],
|
||||
dtype=np.uint8,
|
||||
)
|
||||
if mode == "rgb_array":
|
||||
return image
|
||||
else:
|
||||
if not (image.min() == 0 and image.max() == 0):
|
||||
self._alt_viewer.imshow(image)
|
||||
return self._alt_viewer.isopen
|
||||
class _ProcWorker(mp.Process):
|
||||
def __init__(self, pipe: mp.connection.Pip) -> None:
|
||||
super().__init__()
|
||||
self.pipe = pipe
|
||||
|
||||
def _run_server(self) -> None:
|
||||
import asyncio
|
||||
import io
|
||||
import websockets
|
||||
|
||||
async def handler(ws, _path):
|
||||
loop = asyncio.get_running_loop()
|
||||
while True:
|
||||
image = await loop.run_in_executor(None, self.pipe.recv)
|
||||
if image is None:
|
||||
print("Shutting down the websocket server...")
|
||||
return
|
||||
with io.BytesIO() as stream:
|
||||
image.save(stream, format="png")
|
||||
await ws.send(stream.getvalue())
|
||||
|
||||
server = websockets.serve(handler, "127.0.0.1", 5678)
|
||||
asyncio.get_event_loop().run_until_complete(server)
|
||||
asyncio.get_event_loop().run_forever()
|
||||
|
||||
def run(self) -> None:
|
||||
try:
|
||||
self._run_server()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except Exception as e:
|
||||
print("Exception in websocket server")
|
||||
raise e
|
||||
|
||||
self._websock_server_pipe, pipe = mp.Pipe()
|
||||
worker = _ProcWorker(pipe)
|
||||
worker.start()
|
||||
|
||||
def _render_image(self):
|
||||
from PIL import Image
|
||||
|
||||
self._mj_offscreen_viewer._set_mujoco_buffers()
|
||||
self._mj_offscreen_viewer.render(640, 480)
|
||||
image_array = np.asarray(
|
||||
self._mj_offscreen_viewer.read_pixels(640, 480, depth=False)[::-1, :, :],
|
||||
dtype=np.uint8,
|
||||
)
|
||||
return Image.fromarray(image_array)
|
||||
|
||||
def render(self, *args, **kwargs) -> Any:
|
||||
if self._websock_port is not None:
|
||||
if self._mj_offscreen_viewer is None:
|
||||
self._setup_websocket_viewer()
|
||||
self._websock_server_pipe.send(self._render_image())
|
||||
return None
|
||||
else:
|
||||
return self.wrapped_env.render(mode=mode, **kwargs)
|
||||
return self.wrapped_env.render(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
@ -450,8 +490,8 @@ class MazeEnv(gym.Env):
|
||||
|
||||
def close(self) -> None:
|
||||
self.wrapped_env.close()
|
||||
if self._alt_viewer is not None:
|
||||
self._alt_viewer.close()
|
||||
if self._websock_server_pipe is not None:
|
||||
self._websock_server_pipe.send(self._render_image())
|
||||
|
||||
|
||||
def _add_object_ball(
|
||||
|
Loading…
Reference in New Issue
Block a user