Websocket viewer
This commit is contained in:
parent
e12e38e64e
commit
9e16a68dc2
@ -10,7 +10,7 @@ import itertools as it
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from typing import Any, List, Tuple, Type
|
from typing import Any, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -33,7 +33,7 @@ class MazeEnv(gym.Env):
|
|||||||
inner_reward_scaling: float = 1.0,
|
inner_reward_scaling: float = 1.0,
|
||||||
restitution_coef: float = 0.8,
|
restitution_coef: float = 0.8,
|
||||||
task_kwargs: dict = {},
|
task_kwargs: dict = {},
|
||||||
use_alt_viewer: bool = False,
|
websock_render_port: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._task = maze_task(maze_size_scaling, **task_kwargs)
|
self._task = maze_task(maze_size_scaling, **task_kwargs)
|
||||||
@ -191,8 +191,9 @@ class MazeEnv(gym.Env):
|
|||||||
self.world_tree = tree
|
self.world_tree = tree
|
||||||
self.wrapped_env = model_cls(file_path=file_path, **kwargs)
|
self.wrapped_env = model_cls(file_path=file_path, **kwargs)
|
||||||
self.observation_space = self._get_obs_space()
|
self.observation_space = self._get_obs_space()
|
||||||
self._use_alt_viewer = use_alt_viewer
|
self._websock_port = websock_port
|
||||||
self._alt_viewer = None
|
self._mj_offscreen_viewer = None
|
||||||
|
self._websock_server_pipe = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_extended_obs(self) -> bool:
|
def has_extended_obs(self) -> bool:
|
||||||
@ -356,36 +357,75 @@ class MazeEnv(gym.Env):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def viewer(self) -> Any:
|
def viewer(self) -> Any:
|
||||||
if self._use_alt_viewer:
|
if self._websock_port is not None:
|
||||||
return self._alt_viewer
|
return self._mj_viewer
|
||||||
else:
|
else:
|
||||||
return self.wrapped_env.viewer
|
return self.wrapped_env.viewer
|
||||||
|
|
||||||
def render(self, mode="human", **kwargs) -> Any:
|
def _setup_websocket_server(self) -> None:
|
||||||
if self._use_alt_viewer:
|
import multiprocessing as mp
|
||||||
if self._alt_viewer is None:
|
|
||||||
from gym.envs.classic_control.rendering import SimpleImageViewer
|
|
||||||
import mujoco_py
|
|
||||||
|
|
||||||
self._mj_viewer = mujoco_py.MjRenderContextOffscreen(
|
self._mj_offscreen_viewer = mujoco_py.MjRenderContextOffscreen(
|
||||||
self.wrapped_env.sim
|
self.wrapped_env.sim
|
||||||
)
|
)
|
||||||
self._alt_viewer = SimpleImageViewer()
|
|
||||||
|
|
||||||
self._mj_viewer._set_mujoco_buffers()
|
class _ProcWorker(mp.Process):
|
||||||
self._mj_viewer.render(640, 480)
|
def __init__(self, pipe: mp.connection.Pip) -> None:
|
||||||
image = np.asarray(
|
super().__init__()
|
||||||
self._mj_viewer.read_pixels(640, 480, depth=False)[::-1, :, :],
|
self.pipe = pipe
|
||||||
|
|
||||||
|
def _run_server(self) -> None:
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
async def handler(ws, _path):
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
while True:
|
||||||
|
image = await loop.run_in_executor(None, self.pipe.recv)
|
||||||
|
if image is None:
|
||||||
|
print("Shutting down the websocket server...")
|
||||||
|
return
|
||||||
|
with io.BytesIO() as stream:
|
||||||
|
image.save(stream, format="png")
|
||||||
|
await ws.send(stream.getvalue())
|
||||||
|
|
||||||
|
server = websockets.serve(handler, "127.0.0.1", 5678)
|
||||||
|
asyncio.get_event_loop().run_until_complete(server)
|
||||||
|
asyncio.get_event_loop().run_forever()
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
try:
|
||||||
|
self._run_server()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
print("Exception in websocket server")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
self._websock_server_pipe, pipe = mp.Pipe()
|
||||||
|
worker = _ProcWorker(pipe)
|
||||||
|
worker.start()
|
||||||
|
|
||||||
|
def _render_image(self):
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
self._mj_offscreen_viewer._set_mujoco_buffers()
|
||||||
|
self._mj_offscreen_viewer.render(640, 480)
|
||||||
|
image_array = np.asarray(
|
||||||
|
self._mj_offscreen_viewer.read_pixels(640, 480, depth=False)[::-1, :, :],
|
||||||
dtype=np.uint8,
|
dtype=np.uint8,
|
||||||
)
|
)
|
||||||
if mode == "rgb_array":
|
return Image.fromarray(image_array)
|
||||||
return image
|
|
||||||
|
def render(self, *args, **kwargs) -> Any:
|
||||||
|
if self._websock_port is not None:
|
||||||
|
if self._mj_offscreen_viewer is None:
|
||||||
|
self._setup_websocket_viewer()
|
||||||
|
self._websock_server_pipe.send(self._render_image())
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
if not (image.min() == 0 and image.max() == 0):
|
return self.wrapped_env.render(*args, **kwargs)
|
||||||
self._alt_viewer.imshow(image)
|
|
||||||
return self._alt_viewer.isopen
|
|
||||||
else:
|
|
||||||
return self.wrapped_env.render(mode=mode, **kwargs)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_space(self):
|
def action_space(self):
|
||||||
@ -450,8 +490,8 @@ class MazeEnv(gym.Env):
|
|||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
self.wrapped_env.close()
|
self.wrapped_env.close()
|
||||||
if self._alt_viewer is not None:
|
if self._websock_server_pipe is not None:
|
||||||
self._alt_viewer.close()
|
self._websock_server_pipe.send(self._render_image())
|
||||||
|
|
||||||
|
|
||||||
def _add_object_ball(
|
def _add_object_ball(
|
||||||
|
Loading…
Reference in New Issue
Block a user