Alternative rendering
This commit is contained in:
parent
72fa14d786
commit
100432fceb
92
mujoco_maze/alt_rendering.py
Normal file
92
mujoco_maze/alt_rendering.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
"""
|
||||||
|
2D rendering framework
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
if "Apple" in sys.version:
|
||||||
|
if "DYLD_FALLBACK_LIBRARY_PATH" in os.environ:
|
||||||
|
os.environ["DYLD_FALLBACK_LIBRARY_PATH"] += ":/usr/lib"
|
||||||
|
# (JDS 2016/04/15): avoid bug on Anaconda 2.3.0 / Yosemite
|
||||||
|
|
||||||
|
from gym import error
|
||||||
|
|
||||||
|
|
||||||
|
import pyglet
|
||||||
|
|
||||||
|
from pyglet import gl
|
||||||
|
|
||||||
|
|
||||||
|
RAD2DEG = 57.29577951308232
|
||||||
|
|
||||||
|
|
||||||
|
def get_display(spec):
|
||||||
|
"""Convert a display specification (such as :0) into an actual Display
|
||||||
|
object.
|
||||||
|
|
||||||
|
Pyglet only supports multiple Displays on Linux.
|
||||||
|
"""
|
||||||
|
if spec is None:
|
||||||
|
return pyglet.canvas.get_display()
|
||||||
|
# returns already available pyglet_display,
|
||||||
|
# if there is no pyglet display available then it creates one
|
||||||
|
elif isinstance(spec, str):
|
||||||
|
return pyglet.canvas.Display(spec)
|
||||||
|
else:
|
||||||
|
raise error.Error(
|
||||||
|
f"Invalid display specification: {spec}. (Must be a string like :0 or None.)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_window(width, height, display, **kwargs):
|
||||||
|
"""
|
||||||
|
Will create a pyglet window from the display specification provided.
|
||||||
|
"""
|
||||||
|
screen = display.get_screens() # available screens
|
||||||
|
config = screen[0].get_best_config() # selecting the first screen
|
||||||
|
context = config.create_context(None) # create GL context
|
||||||
|
|
||||||
|
return pyglet.window.Window(
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
display=display,
|
||||||
|
config=config,
|
||||||
|
context=context,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageViewer:
|
||||||
|
def __init__(self, width, height, display=None):
|
||||||
|
display = get_display(display)
|
||||||
|
|
||||||
|
self.width = width
|
||||||
|
self.height = height
|
||||||
|
self.window = get_window(width=width, height=height, display=display)
|
||||||
|
self.window.on_close = self.window_closed_by_user
|
||||||
|
self.isopen = True
|
||||||
|
|
||||||
|
gl.glEnable(gl.GL_BLEND)
|
||||||
|
gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA)
|
||||||
|
|
||||||
|
def imshow(self, arr):
|
||||||
|
assert len(arr.shape) == 3, "You passed in an image with the wrong number shape"
|
||||||
|
self.window.clear()
|
||||||
|
self.window.switch_to()
|
||||||
|
self.window.dispatch_events()
|
||||||
|
image = pyglet.image.ImageData(
|
||||||
|
arr.shape[1], arr.shape[0], "RGB", arr.tobytes(), pitch=arr.shape[1] * -3
|
||||||
|
)
|
||||||
|
image.blit(0, 0) # draw
|
||||||
|
self.window.flip()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self.isopen and sys.meta_path:
|
||||||
|
self.window.close()
|
||||||
|
self.isopen = False
|
||||||
|
|
||||||
|
def window_closed_by_user(self):
|
||||||
|
self.isopen = False
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self.close()
|
@ -10,7 +10,7 @@ import itertools as it
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from typing import List, Tuple, Type
|
from typing import Any, List, Tuple, Type
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -33,7 +33,7 @@ class MazeEnv(gym.Env):
|
|||||||
inner_reward_scaling: float = 1.0,
|
inner_reward_scaling: float = 1.0,
|
||||||
restitution_coef: float = 0.8,
|
restitution_coef: float = 0.8,
|
||||||
task_kwargs: dict = {},
|
task_kwargs: dict = {},
|
||||||
*args,
|
use_alt_viewer: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._task = maze_task(maze_size_scaling, **task_kwargs)
|
self._task = maze_task(maze_size_scaling, **task_kwargs)
|
||||||
@ -70,11 +70,19 @@ class MazeEnv(gym.Env):
|
|||||||
if model_cls.RADIUS is None:
|
if model_cls.RADIUS is None:
|
||||||
raise ValueError("Manual collision needs radius of the model")
|
raise ValueError("Manual collision needs radius of the model")
|
||||||
self._collision = maze_env_utils.CollisionDetector(
|
self._collision = maze_env_utils.CollisionDetector(
|
||||||
structure, size_scaling, torso_x, torso_y, model_cls.RADIUS,
|
structure,
|
||||||
|
size_scaling,
|
||||||
|
torso_x,
|
||||||
|
torso_y,
|
||||||
|
model_cls.RADIUS,
|
||||||
)
|
)
|
||||||
# Now all object balls have size=1.0
|
# Now all object balls have size=1.0
|
||||||
self._objball_collision = maze_env_utils.CollisionDetector(
|
self._objball_collision = maze_env_utils.CollisionDetector(
|
||||||
structure, size_scaling, torso_x, torso_y, self._task.OBJECT_BALL_SIZE,
|
structure,
|
||||||
|
size_scaling,
|
||||||
|
torso_x,
|
||||||
|
torso_y,
|
||||||
|
self._task.OBJECT_BALL_SIZE,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._collision = None
|
self._collision = None
|
||||||
@ -141,7 +149,15 @@ class MazeEnv(gym.Env):
|
|||||||
# Movable block.
|
# Movable block.
|
||||||
self.movable_blocks.append(f"movable_{i}_{j}")
|
self.movable_blocks.append(f"movable_{i}_{j}")
|
||||||
_add_movable_block(
|
_add_movable_block(
|
||||||
worldbody, struct, i, j, size_scaling, x, y, h, height_offset,
|
worldbody,
|
||||||
|
struct,
|
||||||
|
i,
|
||||||
|
j,
|
||||||
|
size_scaling,
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
h,
|
||||||
|
height_offset,
|
||||||
)
|
)
|
||||||
elif struct.is_object_ball():
|
elif struct.is_object_ball():
|
||||||
# Movable Ball
|
# Movable Ball
|
||||||
@ -173,8 +189,10 @@ class MazeEnv(gym.Env):
|
|||||||
_, file_path = tempfile.mkstemp(text=True, suffix=".xml")
|
_, file_path = tempfile.mkstemp(text=True, suffix=".xml")
|
||||||
tree.write(file_path)
|
tree.write(file_path)
|
||||||
self.world_tree = tree
|
self.world_tree = tree
|
||||||
self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs)
|
self.wrapped_env = model_cls(file_path=file_path, **kwargs)
|
||||||
self.observation_space = self._get_obs_space()
|
self.observation_space = self._get_obs_space()
|
||||||
|
self._use_alt_viewer = use_alt_viewer
|
||||||
|
self._alt_viewer = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_extended_obs(self) -> bool:
|
def has_extended_obs(self) -> bool:
|
||||||
@ -337,11 +355,26 @@ class MazeEnv(gym.Env):
|
|||||||
self.data.site_xpos[idx][: len(goal.pos)] = goal.pos
|
self.data.site_xpos[idx][: len(goal.pos)] = goal.pos
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def viewer(self):
|
def viewer(self) -> Any:
|
||||||
|
if self._use_alt_viewer:
|
||||||
|
return self._alt_viewer
|
||||||
|
else:
|
||||||
return self.wrapped_env.viewer
|
return self.wrapped_env.viewer
|
||||||
|
|
||||||
def render(self, *args, **kwargs):
|
def render(self, mode="human", **kwargs) -> Any:
|
||||||
return self.wrapped_env.render(*args, **kwargs)
|
if self._use_alt_viewer:
|
||||||
|
image = self.wrapped_env.sim.render(640, 480)[::-1, :, :]
|
||||||
|
if self._alt_viewer is None:
|
||||||
|
from mujoco_maze.alt_rendering import ImageViewer
|
||||||
|
|
||||||
|
self._alt_viewer = ImageViewer(640, 480)
|
||||||
|
if mode == "rgb_array":
|
||||||
|
return image
|
||||||
|
else:
|
||||||
|
self._alt_viewer.imshow(image)
|
||||||
|
return self._alt_viewer.isopen
|
||||||
|
else:
|
||||||
|
return self.wrapped_env.render(mode=mode, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_space(self):
|
def action_space(self):
|
||||||
@ -406,6 +439,8 @@ class MazeEnv(gym.Env):
|
|||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
self.wrapped_env.close()
|
self.wrapped_env.close()
|
||||||
|
if self._alt_viewer is not None:
|
||||||
|
self._alt_viewer.close()
|
||||||
|
|
||||||
|
|
||||||
def _add_object_ball(
|
def _add_object_ball(
|
||||||
@ -479,7 +514,10 @@ def _add_movable_block(
|
|||||||
shrink = 1.0
|
shrink = 1.0
|
||||||
size = size_scaling * 0.5 * shrink
|
size = size_scaling * 0.5 * shrink
|
||||||
movable_body = ET.SubElement(
|
movable_body = ET.SubElement(
|
||||||
worldbody, "body", name=f"movable_{i}_{j}", pos=f"{x} {y} {h}",
|
worldbody,
|
||||||
|
"body",
|
||||||
|
name=f"movable_{i}_{j}",
|
||||||
|
pos=f"{x} {y} {h}",
|
||||||
)
|
)
|
||||||
ET.SubElement(
|
ET.SubElement(
|
||||||
movable_body,
|
movable_body,
|
||||||
|
@ -22,7 +22,7 @@ class PointEnv(AgentModel):
|
|||||||
|
|
||||||
VELOCITY_LIMITS: float = 10.0
|
VELOCITY_LIMITS: float = 10.0
|
||||||
|
|
||||||
def __init__(self, file_path: Optional[str] = None):
|
def __init__(self, file_path: Optional[str] = None) -> None:
|
||||||
super().__init__(file_path, 1)
|
super().__init__(file_path, 1)
|
||||||
high = np.inf * np.ones(6, dtype=np.float32)
|
high = np.inf * np.ones(6, dtype=np.float32)
|
||||||
high[3:] = self.VELOCITY_LIMITS * 1.2
|
high[3:] = self.VELOCITY_LIMITS * 1.2
|
||||||
|
Loading…
Reference in New Issue
Block a user