Formatting
This commit is contained in:
parent
e1af62cdc7
commit
4087203b06
@ -27,16 +27,13 @@ class AgentModel(ABC, MujocoEnv, EzPickle):
|
||||
|
||||
@abstractmethod
|
||||
def _get_obs(self) -> np.ndarray:
|
||||
"""Returns the observation from the model.
|
||||
"""
|
||||
"""Returns the observation from the model."""
|
||||
pass
|
||||
|
||||
def get_xy(self) -> np.ndarray:
|
||||
"""Returns the coordinate of the agent.
|
||||
"""
|
||||
"""Returns the coordinate of the agent."""
|
||||
pass
|
||||
|
||||
def set_xy(self, xy: np.ndarray) -> None:
|
||||
"""Set the coordinate of the agent.
|
||||
"""
|
||||
"""Set the coordinate of the agent."""
|
||||
pass
|
||||
|
@ -82,7 +82,9 @@ class AntEnv(AgentModel):
|
||||
|
||||
def reset_model(self):
|
||||
qpos = self.init_qpos + self.np_random.uniform(
|
||||
size=self.model.nq, low=-0.1, high=0.1,
|
||||
size=self.model.nq,
|
||||
low=-0.1,
|
||||
high=0.1,
|
||||
)
|
||||
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
|
||||
|
||||
|
@ -370,17 +370,16 @@ class MazeEnv(gym.Env):
|
||||
dtype=np.uint8,
|
||||
)
|
||||
|
||||
def render(self, mode="human", **kwargs) -> Any:
|
||||
if self._websock_port is not None:
|
||||
def render(self, mode="human", **kwargs) -> Optional[np.ndarray]:
|
||||
if mode == "human" and self._websock_port is not None:
|
||||
if self._mj_offscreen_viewer is None:
|
||||
from mujoco_py import MjRenderContextOffscreen as MjRenderOffscreen
|
||||
from mujoco_py import MjRenderContextOffscreen as MjRCO
|
||||
|
||||
from mujoco_maze.websock_viewer import start_server
|
||||
|
||||
self._mj_offscreen_viewer = MjRenderOffscreen(self.wrapped_env.sim)
|
||||
self._mj_offscreen_viewer = MjRCO(self.wrapped_env.sim)
|
||||
self._websock_server_pipe = start_server(self._websock_port)
|
||||
self._websock_server_pipe.send(self._render_image())
|
||||
return True
|
||||
else:
|
||||
return self.wrapped_env.render(mode, **kwargs)
|
||||
|
||||
|
@ -83,7 +83,9 @@ class MazeCell(Enum):
|
||||
|
||||
class Line:
|
||||
def __init__(
|
||||
self, p1: Union[Sequence[float], Point], p2: Union[Sequence[float], Point],
|
||||
self,
|
||||
p1: Union[Sequence[float], Point],
|
||||
p2: Union[Sequence[float], Point],
|
||||
) -> None:
|
||||
self.p1 = p1 if isinstance(p1, Point) else np.complex(*p1)
|
||||
self.p2 = p2 if isinstance(p2, Point) else np.complex(*p2)
|
||||
@ -141,8 +143,7 @@ class Collision:
|
||||
|
||||
|
||||
class CollisionDetector:
|
||||
"""For manual collision detection.
|
||||
"""
|
||||
"""For manual collision detection."""
|
||||
|
||||
EPS: float = 0.05
|
||||
NEIGHBORS: List[Tuple[int, int]] = [[0, -1], [-1, 0], [0, 1], [1, 0]]
|
||||
|
@ -55,10 +55,14 @@ class SwimmerEnv(AgentModel):
|
||||
|
||||
def reset_model(self) -> np.ndarray:
|
||||
qpos = self.init_qpos + self.np_random.uniform(
|
||||
low=-0.1, high=0.1, size=self.model.nq,
|
||||
low=-0.1,
|
||||
high=0.1,
|
||||
size=self.model.nq,
|
||||
)
|
||||
qvel = self.init_qvel + self.np_random.uniform(
|
||||
low=-0.1, high=0.1, size=self.model.nv,
|
||||
low=-0.1,
|
||||
high=0.1,
|
||||
size=self.model.nv,
|
||||
)
|
||||
|
||||
self.set_state(qpos, qvel)
|
||||
|
@ -4,7 +4,6 @@ import multiprocessing as mp
|
||||
|
||||
import fastapi
|
||||
import uvicorn
|
||||
|
||||
from PIL import Image
|
||||
|
||||
HTML = """
|
||||
|
Loading…
Reference in New Issue
Block a user