Formatting
This commit is contained in:
parent
e1af62cdc7
commit
4087203b06
@ -27,16 +27,13 @@ class AgentModel(ABC, MujocoEnv, EzPickle):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _get_obs(self) -> np.ndarray:
|
def _get_obs(self) -> np.ndarray:
|
||||||
"""Returns the observation from the model.
|
"""Returns the observation from the model."""
|
||||||
"""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_xy(self) -> np.ndarray:
|
def get_xy(self) -> np.ndarray:
|
||||||
"""Returns the coordinate of the agent.
|
"""Returns the coordinate of the agent."""
|
||||||
"""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def set_xy(self, xy: np.ndarray) -> None:
|
def set_xy(self, xy: np.ndarray) -> None:
|
||||||
"""Set the coordinate of the agent.
|
"""Set the coordinate of the agent."""
|
||||||
"""
|
|
||||||
pass
|
pass
|
||||||
|
@ -82,7 +82,9 @@ class AntEnv(AgentModel):
|
|||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
qpos = self.init_qpos + self.np_random.uniform(
|
qpos = self.init_qpos + self.np_random.uniform(
|
||||||
size=self.model.nq, low=-0.1, high=0.1,
|
size=self.model.nq,
|
||||||
|
low=-0.1,
|
||||||
|
high=0.1,
|
||||||
)
|
)
|
||||||
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
|
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
|
||||||
|
|
||||||
|
@ -370,17 +370,16 @@ class MazeEnv(gym.Env):
|
|||||||
dtype=np.uint8,
|
dtype=np.uint8,
|
||||||
)
|
)
|
||||||
|
|
||||||
def render(self, mode="human", **kwargs) -> Any:
|
def render(self, mode="human", **kwargs) -> Optional[np.ndarray]:
|
||||||
if self._websock_port is not None:
|
if mode == "human" and self._websock_port is not None:
|
||||||
if self._mj_offscreen_viewer is None:
|
if self._mj_offscreen_viewer is None:
|
||||||
from mujoco_py import MjRenderContextOffscreen as MjRenderOffscreen
|
from mujoco_py import MjRenderContextOffscreen as MjRCO
|
||||||
|
|
||||||
from mujoco_maze.websock_viewer import start_server
|
from mujoco_maze.websock_viewer import start_server
|
||||||
|
|
||||||
self._mj_offscreen_viewer = MjRenderOffscreen(self.wrapped_env.sim)
|
self._mj_offscreen_viewer = MjRCO(self.wrapped_env.sim)
|
||||||
self._websock_server_pipe = start_server(self._websock_port)
|
self._websock_server_pipe = start_server(self._websock_port)
|
||||||
self._websock_server_pipe.send(self._render_image())
|
self._websock_server_pipe.send(self._render_image())
|
||||||
return True
|
|
||||||
else:
|
else:
|
||||||
return self.wrapped_env.render(mode, **kwargs)
|
return self.wrapped_env.render(mode, **kwargs)
|
||||||
|
|
||||||
|
@ -83,7 +83,9 @@ class MazeCell(Enum):
|
|||||||
|
|
||||||
class Line:
|
class Line:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, p1: Union[Sequence[float], Point], p2: Union[Sequence[float], Point],
|
self,
|
||||||
|
p1: Union[Sequence[float], Point],
|
||||||
|
p2: Union[Sequence[float], Point],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.p1 = p1 if isinstance(p1, Point) else np.complex(*p1)
|
self.p1 = p1 if isinstance(p1, Point) else np.complex(*p1)
|
||||||
self.p2 = p2 if isinstance(p2, Point) else np.complex(*p2)
|
self.p2 = p2 if isinstance(p2, Point) else np.complex(*p2)
|
||||||
@ -141,8 +143,7 @@ class Collision:
|
|||||||
|
|
||||||
|
|
||||||
class CollisionDetector:
|
class CollisionDetector:
|
||||||
"""For manual collision detection.
|
"""For manual collision detection."""
|
||||||
"""
|
|
||||||
|
|
||||||
EPS: float = 0.05
|
EPS: float = 0.05
|
||||||
NEIGHBORS: List[Tuple[int, int]] = [[0, -1], [-1, 0], [0, 1], [1, 0]]
|
NEIGHBORS: List[Tuple[int, int]] = [[0, -1], [-1, 0], [0, 1], [1, 0]]
|
||||||
|
@ -55,10 +55,14 @@ class SwimmerEnv(AgentModel):
|
|||||||
|
|
||||||
def reset_model(self) -> np.ndarray:
|
def reset_model(self) -> np.ndarray:
|
||||||
qpos = self.init_qpos + self.np_random.uniform(
|
qpos = self.init_qpos + self.np_random.uniform(
|
||||||
low=-0.1, high=0.1, size=self.model.nq,
|
low=-0.1,
|
||||||
|
high=0.1,
|
||||||
|
size=self.model.nq,
|
||||||
)
|
)
|
||||||
qvel = self.init_qvel + self.np_random.uniform(
|
qvel = self.init_qvel + self.np_random.uniform(
|
||||||
low=-0.1, high=0.1, size=self.model.nv,
|
low=-0.1,
|
||||||
|
high=0.1,
|
||||||
|
size=self.model.nv,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.set_state(qpos, qvel)
|
self.set_state(qpos, qvel)
|
||||||
|
@ -4,7 +4,6 @@ import multiprocessing as mp
|
|||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
HTML = """
|
HTML = """
|
||||||
|
Loading…
Reference in New Issue
Block a user