Make use of radius in collision detection

This commit is contained in:
kngwyu 2020-07-13 01:15:26 +09:00
parent a67db885a2
commit 8907d0a2c0
4 changed files with 42 additions and 17 deletions

View File

@ -49,6 +49,7 @@ class AntEnv(AgentModel):
) -> None: ) -> None:
self._ctrl_cost_weight = ctrl_cost_weight self._ctrl_cost_weight = ctrl_cost_weight
self._forward_reward_fn = forward_reward_fn self._forward_reward_fn = forward_reward_fn
self.radius = 0.3
super().__init__(file_path, 5) super().__init__(file_path, 5)
def _forward_reward(self, xy_pos_before: np.ndarray) -> Tuple[float, np.ndarray]: def _forward_reward(self, xy_pos_before: np.ndarray) -> Tuple[float, np.ndarray]:

View File

@ -31,7 +31,7 @@ class MazeEnv(gym.Env):
maze_height: float = 0.5, maze_height: float = 0.5,
maze_size_scaling: float = 4.0, maze_size_scaling: float = 4.0,
inner_reward_scaling: float = 1.0, inner_reward_scaling: float = 1.0,
restitution_coef: float = 0.9, restitution_coef: float = 0.8,
*args, *args,
**kwargs, **kwargs,
) -> None: ) -> None:
@ -246,6 +246,7 @@ class MazeEnv(gym.Env):
self.world_tree = tree self.world_tree = tree
self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs) self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs)
self.observation_space = self._get_obs_space() self.observation_space = self._get_obs_space()
self._debug = False
def get_ori(self) -> float: def get_ori(self) -> float:
return self.wrapped_env.get_ori() return self.wrapped_env.get_ori()
@ -435,16 +436,21 @@ class MazeEnv(gym.Env):
old_pos = self.wrapped_env.get_xy() old_pos = self.wrapped_env.get_xy()
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action) inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
new_pos = self.wrapped_env.get_xy() new_pos = self.wrapped_env.get_xy()
# Checks that new_position is in the wall # Checks that the new_position is in the wall
intersection = self._collision.detect_intersection(old_pos, new_pos) intersection = self._collision.detect_intersection(
old_pos, new_pos, self.wrapped_env.radius,
)
if intersection is not None: if intersection is not None:
pos = intersection + (intersection - new_pos) * self._restitution_coef pos = intersection + (intersection - new_pos) * self._restitution_coef
# Checks that pos is in the wall intersection2 = self._collision.detect_intersection(
intersection2 = self._collision.detect_intersection(old_pos, pos) old_pos, pos, self.wrapped_env.radius,
)
# If pos is also not in the wall, we give up computing the position
if intersection2 is not None: if intersection2 is not None:
# If pos is not in the wall, we give up computing the position
pos = old_pos pos = old_pos
self.wrapped_env.set_collision(pos, self._restitution_coef) self.wrapped_env.set_collision(pos, self._restitution_coef)
if self._debug:
print(f"new_pos: {new_pos}, pos: {pos}")
else: else:
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action) inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
next_obs = self._get_obs() next_obs = self._get_obs()

View File

@ -8,7 +8,7 @@ Based on `models`_ and `rllab`_.
import itertools as it import itertools as it
from enum import Enum from enum import Enum
from typing import Any, List, Optional, Sequence, Tuple from typing import Any, List, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
@ -77,11 +77,25 @@ class MazeCell(Enum):
class Line: class Line:
def __init__(self, p1: Sequence[float], p2: Sequence[float]) -> None: def __init__(
self, p1: Union[Point, Sequence[float]], p2: Union[Point, Sequence[float]]
) -> None:
if isinstance(p1, Point):
self.p1 = p1
else:
self.p1 = np.complex(*p1) self.p1 = np.complex(*p1)
if isinstance(p2, Point):
self.p2 = p2
else:
self.p2 = np.complex(*p2) self.p2 = np.complex(*p2)
self.conj_v1 = np.conjugate(self.p2 - self.p1) self.conj_v1 = np.conjugate(self.p2 - self.p1)
def extend(self, dist: float) -> Tuple[Self, Point]:
v = self.p2 - self.p1
extended_v = v * dist / np.absolute(v)
p2 = self.p2 + extended_v
return Line(self.p1, p2), extended_v
def _intersect(self, other: Self) -> bool: def _intersect(self, other: Self) -> bool:
v2 = other.p1 - self.p1 v2 = other.p1 - self.p1
v3 = other.p2 - self.p1 v3 = other.p2 - self.p1
@ -135,15 +149,17 @@ class Collision:
for dx, dy in self.NEIGHBORS: for dx, dy in self.NEIGHBORS:
if not is_empty(i + dy, j + dx): if not is_empty(i + dy, j + dx):
continue continue
self.lines.append(Line( self.lines.append(
Line(
(max_x if dx == 1 else min_x, max_y if dy == 1 else min_y), (max_x if dx == 1 else min_x, max_y if dy == 1 else min_y),
(min_x if dx == -1 else max_x, min_y if dy == -1 else max_y), (min_x if dx == -1 else max_x, min_y if dy == -1 else max_y),
)) )
)
def detect_intersection( def detect_intersection(
self, old_pos: np.ndarray, new_pos: np.ndarray self, old_pos: np.ndarray, new_pos: np.ndarray, radius
) -> Optional[np.ndarray]: ) -> Optional[np.ndarray]:
move = Line(old_pos, new_pos) move, extended = Line(old_pos, new_pos).extend(radius)
intersections = [] intersections = []
for line in self.lines: for line in self.lines:
intersection = line.intersect(move) intersection = line.intersect(move)
@ -157,4 +173,4 @@ class Collision:
new_dist = np.linalg.norm(new_pos - old_pos) new_dist = np.linalg.norm(new_pos - old_pos)
if new_dist < dist: if new_dist < dist:
pos, dist = new_pos, new_dist pos, dist = new_pos, new_dist
return pos return pos - np.array([extended.real, extended.imag])

View File

@ -19,6 +19,7 @@ class PointEnv(AgentModel):
FILE: str = "point.xml" FILE: str = "point.xml"
ORI_IND: int = 2 ORI_IND: int = 2
MANUAL_COLLISION: bool = True MANUAL_COLLISION: bool = True
radius: float = 0.5
VELOCITY_LIMITS: float = 10.0 VELOCITY_LIMITS: float = 10.0
@ -28,6 +29,7 @@ class PointEnv(AgentModel):
high[3:] = self.VELOCITY_LIMITS * 1.2 high[3:] = self.VELOCITY_LIMITS * 1.2
high[self.ORI_IND] = np.pi high[self.ORI_IND] = np.pi
low = -high low = -high
self.radius = 0.5
self.observation_space = gym.spaces.Box(low, high) self.observation_space = gym.spaces.Box(low, high)
def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]: def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]: