Introduce MazeTask for customizability
This commit is contained in:
parent
c91a4bc8a7
commit
d08cfe5d0e
@ -1,5 +1,8 @@
|
|||||||
import gym
|
import gym
|
||||||
|
|
||||||
|
from mujoco_maze.maze_task import TaskRegistry
|
||||||
|
|
||||||
|
|
||||||
MAZE_IDS = ["Maze", "Push", "Fall"] # TODO: Block, BlockMaze
|
MAZE_IDS = ["Maze", "Push", "Fall"] # TODO: Block, BlockMaze
|
||||||
|
|
||||||
|
|
||||||
@ -12,36 +15,24 @@ def _get_kwargs(maze_id: str) -> tuple:
|
|||||||
|
|
||||||
|
|
||||||
for maze_id in MAZE_IDS:
|
for maze_id in MAZE_IDS:
|
||||||
gym.envs.register(
|
for i, task_cls in enumerate(TaskRegistry.REGISTRY[maze_id]):
|
||||||
id="Ant{}-v0".format(maze_id),
|
gym.envs.register(
|
||||||
entry_point="mujoco_maze.ant_maze_env:AntMazeEnv",
|
id=f"Ant{maze_id}-v{i}",
|
||||||
kwargs=dict(maze_size_scaling=8.0, **_get_kwargs(maze_id)),
|
entry_point="mujoco_maze.ant_maze_env:AntMazeEnv",
|
||||||
max_episode_steps=1000,
|
kwargs=dict(maze_task=task_cls, maze_size_scaling=8.0),
|
||||||
reward_threshold=-1000,
|
max_episode_steps=1000,
|
||||||
)
|
reward_threshold=task_cls.REWARD_THRESHOLD,
|
||||||
gym.envs.register(
|
)
|
||||||
id="Ant{}-v1".format(maze_id),
|
|
||||||
entry_point="mujoco_maze.ant_maze_env:AntMazeEnv",
|
|
||||||
kwargs=dict(maze_size_scaling=8.0, **_get_kwargs(maze_id)),
|
|
||||||
max_episode_steps=1000,
|
|
||||||
reward_threshold=0.9,
|
|
||||||
)
|
|
||||||
|
|
||||||
for maze_id in MAZE_IDS:
|
for maze_id in MAZE_IDS:
|
||||||
gym.envs.register(
|
for i, task_cls in enumerate(TaskRegistry.REGISTRY[maze_id]):
|
||||||
id="Point{}-v0".format(maze_id),
|
gym.envs.register(
|
||||||
entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
|
id=f"Point{maze_id}-v{i}",
|
||||||
kwargs=_get_kwargs(maze_id),
|
entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
|
||||||
max_episode_steps=1000,
|
kwargs=dict(maze_task=task_cls),
|
||||||
reward_threshold=-1000,
|
max_episode_steps=1000,
|
||||||
)
|
reward_threshold=task_cls.REWARD_THRESHOLD,
|
||||||
gym.envs.register(
|
)
|
||||||
id="Point{}-v1".format(maze_id),
|
|
||||||
entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
|
|
||||||
kwargs=dict(**_get_kwargs(maze_id), dense_reward=False),
|
|
||||||
max_episode_steps=1000,
|
|
||||||
reward_threshold=0.9,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
|
@ -16,17 +16,17 @@
|
|||||||
"""Adapted from rllab maze_env.py."""
|
"""Adapted from rllab maze_env.py."""
|
||||||
|
|
||||||
import itertools as it
|
import itertools as it
|
||||||
import math
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import gym
|
import gym
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
|
|
||||||
from typing import Callable, Type, Union
|
from typing import Type
|
||||||
|
|
||||||
from mujoco_maze.agent_model import AgentModel
|
from mujoco_maze.agent_model import AgentModel
|
||||||
from mujoco_maze import maze_env_utils
|
from mujoco_maze import maze_env_utils
|
||||||
|
from mujoco_maze import maze_task
|
||||||
|
|
||||||
# Directory that contains mujoco xml files.
|
# Directory that contains mujoco xml files.
|
||||||
MODEL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/assets"
|
MODEL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/assets"
|
||||||
@ -36,26 +36,23 @@ class MazeEnv(gym.Env):
|
|||||||
MODEL_CLASS: Type[AgentModel] = AgentModel
|
MODEL_CLASS: Type[AgentModel] = AgentModel
|
||||||
|
|
||||||
MANUAL_COLLISION: bool = False
|
MANUAL_COLLISION: bool = False
|
||||||
# For preventing the point from going through the wall
|
BLOCK_EPS: float = 0.0001
|
||||||
SIZE_EPS = 0.0001
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
maze_id=None,
|
maze_task: Type[maze_task.MazeTask] = maze_task.SingleGoalSparseEMaze(),
|
||||||
n_bins=0,
|
n_bins: int = 0,
|
||||||
sensor_range=3.0,
|
sensor_range: float = 3.0,
|
||||||
sensor_span=2 * math.pi,
|
sensor_span: float = 2 * np.pi,
|
||||||
observe_blocks=False,
|
observe_blocks: float = False,
|
||||||
put_spin_near_agent=False,
|
put_spin_near_agent: float = False,
|
||||||
top_down_view=False,
|
top_down_view: float = False,
|
||||||
dense_reward=True,
|
|
||||||
maze_height: float = 0.5,
|
maze_height: float = 0.5,
|
||||||
maze_size_scaling: float = 4.0,
|
maze_size_scaling: float = 4.0,
|
||||||
goal_sampler: Union[str, np.ndarray, Callable[[], np.ndarray]] = "default",
|
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._maze_id = maze_id
|
self._task = maze_task()
|
||||||
|
|
||||||
xml_path = os.path.join(MODEL_DIR, self.MODEL_CLASS.FILE)
|
xml_path = os.path.join(MODEL_DIR, self.MODEL_CLASS.FILE)
|
||||||
tree = ET.parse(xml_path)
|
tree = ET.parse(xml_path)
|
||||||
@ -72,15 +69,11 @@ class MazeEnv(gym.Env):
|
|||||||
self._top_down_view = top_down_view
|
self._top_down_view = top_down_view
|
||||||
self._collision_coef = 0.1
|
self._collision_coef = 0.1
|
||||||
|
|
||||||
self._maze_structure = structure = maze_env_utils.construct_maze(
|
self._maze_structure = structure = self._task.create_maze()
|
||||||
maze_id=self._maze_id
|
|
||||||
)
|
|
||||||
# Elevate the maze to allow for falling.
|
# Elevate the maze to allow for falling.
|
||||||
self.elevated = any(maze_env_utils.MazeCell.CHASM in row for row in structure)
|
self.elevated = any(maze_env_utils.MazeCell.CHASM in row for row in structure)
|
||||||
# Are there any movable blocks?
|
# Are there any movable blocks?
|
||||||
self.blocks = any(
|
self.blocks = any(any(r.can_move() for r in row) for row in structure)
|
||||||
any(r.can_move() for r in row) for row in structure
|
|
||||||
)
|
|
||||||
|
|
||||||
torso_x, torso_y = self._find_robot()
|
torso_x, torso_y = self._find_robot()
|
||||||
self._init_torso_x = torso_x
|
self._init_torso_x = torso_x
|
||||||
@ -117,13 +110,13 @@ class MazeEnv(gym.Env):
|
|||||||
for j in range(len(structure[0])):
|
for j in range(len(structure[0])):
|
||||||
struct = structure[i][j]
|
struct = structure[i][j]
|
||||||
if struct.is_robot() and self._put_spin_near_agent:
|
if struct.is_robot() and self._put_spin_near_agent:
|
||||||
struct = maze_env_utils.Move.SpinXY
|
struct = maze_env_utils.MazeCell.SpinXY
|
||||||
if self.elevated and not struct.is_chasm():
|
if self.elevated and not struct.is_chasm():
|
||||||
# Create elevated platform.
|
# Create elevated platform.
|
||||||
x = j * size_scaling - torso_x
|
x = j * size_scaling - torso_x
|
||||||
y = i * size_scaling - torso_y
|
y = i * size_scaling - torso_y
|
||||||
h = height / 2 * size_scaling
|
h = height / 2 * size_scaling
|
||||||
size = 0.5 * size_scaling + self.SIZE_EPS
|
size = 0.5 * size_scaling + self.BLOCK_EPS
|
||||||
ET.SubElement(
|
ET.SubElement(
|
||||||
worldbody,
|
worldbody,
|
||||||
"geom",
|
"geom",
|
||||||
@ -142,7 +135,7 @@ class MazeEnv(gym.Env):
|
|||||||
x = j * size_scaling - torso_x
|
x = j * size_scaling - torso_x
|
||||||
y = i * size_scaling - torso_y
|
y = i * size_scaling - torso_y
|
||||||
h = height / 2 * size_scaling
|
h = height / 2 * size_scaling
|
||||||
size = 0.5 * size_scaling + self.SIZE_EPS
|
size = 0.5 * size_scaling + self.BLOCK_EPS
|
||||||
ET.SubElement(
|
ET.SubElement(
|
||||||
worldbody,
|
worldbody,
|
||||||
"geom",
|
"geom",
|
||||||
@ -172,7 +165,7 @@ class MazeEnv(gym.Env):
|
|||||||
)
|
)
|
||||||
y = i * size_scaling - torso_y
|
y = i * size_scaling - torso_y
|
||||||
h = height / 2 * size_scaling * height_shrink
|
h = height / 2 * size_scaling * height_shrink
|
||||||
size = 0.5 * size_scaling * shrink + self.SIZE_EPS
|
size = 0.5 * size_scaling * shrink + self.BLOCK_EPS
|
||||||
movable_body = ET.SubElement(
|
movable_body = ET.SubElement(
|
||||||
worldbody,
|
worldbody,
|
||||||
"body",
|
"body",
|
||||||
@ -257,29 +250,6 @@ class MazeEnv(gym.Env):
|
|||||||
tree.write(file_path)
|
tree.write(file_path)
|
||||||
self.wrapped_env = self.MODEL_CLASS(*args, file_path=file_path, **kwargs)
|
self.wrapped_env = self.MODEL_CLASS(*args, file_path=file_path, **kwargs)
|
||||||
|
|
||||||
# Set reward function
|
|
||||||
self._reward_fn = _reward_fn(maze_id, dense_reward)
|
|
||||||
|
|
||||||
# Set goal sampler
|
|
||||||
if isinstance(goal_sampler, str):
|
|
||||||
if goal_sampler == "random":
|
|
||||||
self._goal_sampler = lambda: np.random.uniform((-4, -4), (20, 20))
|
|
||||||
elif goal_sampler == "default":
|
|
||||||
default_goal = _default_goal(maze_id, size_scaling)
|
|
||||||
self._goal_sampler = lambda: default_goal
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unknown goal_sampler: {goal_sampler}")
|
|
||||||
elif isinstance(goal_sampler, np.ndarray):
|
|
||||||
self._goal_sampler = lambda: goal_sampler
|
|
||||||
elif callable(goal_sampler):
|
|
||||||
self._goal_sampler = goal_sampler
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid goal_sampler: {goal_sampler}")
|
|
||||||
self.goal = self._goal_sampler()
|
|
||||||
|
|
||||||
# Set goal function
|
|
||||||
self._goal_fn = _goal_fn(maze_id)
|
|
||||||
|
|
||||||
def get_ori(self):
|
def get_ori(self):
|
||||||
return self.wrapped_env.get_ori()
|
return self.wrapped_env.get_ori()
|
||||||
|
|
||||||
@ -488,7 +458,7 @@ class MazeEnv(gym.Env):
|
|||||||
self.t = 0
|
self.t = 0
|
||||||
self.wrapped_env.reset()
|
self.wrapped_env.reset()
|
||||||
# Sample a new goal
|
# Sample a new goal
|
||||||
self.goal = self._goal_sampler()
|
self._task.sample_goals(self._maze_size_scaling)
|
||||||
if len(self._init_positions) > 1:
|
if len(self._init_positions) > 1:
|
||||||
xy = np.random.choice(self._init_positions)
|
xy = np.random.choice(self._init_positions)
|
||||||
self.wrapped_env.set_xy(xy)
|
self.wrapped_env.set_xy(xy)
|
||||||
@ -540,51 +510,6 @@ class MazeEnv(gym.Env):
|
|||||||
else:
|
else:
|
||||||
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
|
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
|
||||||
next_obs = self._get_obs()
|
next_obs = self._get_obs()
|
||||||
outer_reward = self._reward_fn(next_obs, self.goal)
|
outer_reward = self._task.reward(next_obs)
|
||||||
done = self._goal_fn(next_obs, self.goal)
|
done = self._task.termination(next_obs)
|
||||||
return next_obs, inner_reward + outer_reward, done, info
|
return next_obs, inner_reward + outer_reward, done, info
|
||||||
|
|
||||||
|
|
||||||
def _goal_fn(maze_id: str) -> callable:
|
|
||||||
if maze_id in ["Maze", "Push", "BlockMaze"]:
|
|
||||||
return lambda obs, goal: np.linalg.norm(obs[:2] - goal) <= 0.6
|
|
||||||
elif maze_id == "Fall":
|
|
||||||
return lambda obs, goal: np.linalg.norm(obs[:3] - goal) <= 0.6
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unknown maze id: {maze_id}")
|
|
||||||
|
|
||||||
|
|
||||||
def _reward_fn(maze_id: str, dense: str) -> callable:
|
|
||||||
if dense:
|
|
||||||
if maze_id in ["Maze", "Push", "BlockMaze"]:
|
|
||||||
return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5
|
|
||||||
elif maze_id == "Fall":
|
|
||||||
return lambda obs, goal: -np.sum(np.square(obs[:3] - goal)) ** 0.5
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unknown maze id: {maze_id}")
|
|
||||||
else:
|
|
||||||
if maze_id in ["Maze", "Push", "BlockMaze"]:
|
|
||||||
return (
|
|
||||||
lambda obs, goal: 1.0
|
|
||||||
if np.linalg.norm(obs[:2] - goal) <= 0.6
|
|
||||||
else -0.0001
|
|
||||||
)
|
|
||||||
elif maze_id == "Fall":
|
|
||||||
return (
|
|
||||||
lambda obs, goal: 1.0
|
|
||||||
if np.linalg.norm(obs[:3] - goal) <= 0.6
|
|
||||||
else -0.0001
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unknown maze id: {maze_id}")
|
|
||||||
|
|
||||||
|
|
||||||
def _default_goal(maze_id: str, scale: float) -> np.ndarray:
|
|
||||||
if maze_id == "Maze" or maze_id == "BlockMaze":
|
|
||||||
return np.array([0.0, 2.0 * scale])
|
|
||||||
elif maze_id == "Push":
|
|
||||||
return np.array([0.0, 2.375 * scale])
|
|
||||||
elif maze_id == "Fall":
|
|
||||||
return np.array([0.0, 3.375 * scale, 4.5])
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unknown maze id: {maze_id}")
|
|
||||||
|
@ -77,55 +77,6 @@ class MazeCell(Enum):
|
|||||||
return self.can_move_x() or self.can_move_y() or self.can_move_z()
|
return self.can_move_x() or self.can_move_y() or self.can_move_z()
|
||||||
|
|
||||||
|
|
||||||
def construct_maze(maze_id="Maze"):
|
|
||||||
E, B, C, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.CHASM, MazeCell.ROBOT
|
|
||||||
if maze_id == "Maze":
|
|
||||||
structure = [
|
|
||||||
[B, B, B, B, B],
|
|
||||||
[B, R, E, E, B],
|
|
||||||
[B, B, B, E, B],
|
|
||||||
[B, E, E, E, B],
|
|
||||||
[B, B, B, B, B],
|
|
||||||
]
|
|
||||||
elif maze_id == "Push":
|
|
||||||
structure = [
|
|
||||||
[B, B, B, B, B],
|
|
||||||
[B, E, R, B, B],
|
|
||||||
[B, E, MazeCell.XY, E, B],
|
|
||||||
[B, B, E, B, B],
|
|
||||||
[B, B, B, B, B],
|
|
||||||
]
|
|
||||||
elif maze_id == "Fall":
|
|
||||||
structure = [
|
|
||||||
[B, B, B, B],
|
|
||||||
[B, R, E, B],
|
|
||||||
[B, E, MazeCell.YZ, B],
|
|
||||||
[B, C, C, B],
|
|
||||||
[B, E, E, B],
|
|
||||||
[B, B, B, B],
|
|
||||||
]
|
|
||||||
elif maze_id == "Block":
|
|
||||||
structure = [
|
|
||||||
[B, B, B, B, B],
|
|
||||||
[B, R, E, E, B],
|
|
||||||
[B, E, E, E, B],
|
|
||||||
[B, E, E, E, B],
|
|
||||||
[B, B, B, B, B],
|
|
||||||
]
|
|
||||||
elif maze_id == "BlockMaze":
|
|
||||||
structure = [
|
|
||||||
[B, B, B, B],
|
|
||||||
[B, R, E, B],
|
|
||||||
[B, B, E, B],
|
|
||||||
[B, E, E, B],
|
|
||||||
[B, B, B, B],
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("The provided MazeId %s is not recognized" % maze_id)
|
|
||||||
|
|
||||||
return structure
|
|
||||||
|
|
||||||
|
|
||||||
class Collision:
|
class Collision:
|
||||||
"""For manual collision detection.
|
"""For manual collision detection.
|
||||||
"""
|
"""
|
||||||
|
137
mujoco_maze/maze_task.py
Normal file
137
mujoco_maze/maze_task.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, List, Type
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mujoco_maze.maze_env_utils import MazeCell
|
||||||
|
|
||||||
|
|
||||||
|
class MazeGoal:
|
||||||
|
THRESHOLD: float = 0.6
|
||||||
|
|
||||||
|
def __init__(self, goal: np.ndarray, reward_scale: float = 1.0) -> None:
|
||||||
|
self.goal = goal
|
||||||
|
self.goal_dim = goal.shape[0]
|
||||||
|
self.reward_scale = reward_scale
|
||||||
|
|
||||||
|
def neighbor(self, obs: np.ndarray) -> float:
|
||||||
|
return np.linalg.norm(obs[: self.goal_dim] - self.goal) <= self.THRESHOLD
|
||||||
|
|
||||||
|
def euc_dist(self, obs: np.ndarray) -> float:
|
||||||
|
return np.sum(np.square(obs[: self.goal_dim] - self.goal)) ** 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class MazeTask(ABC):
|
||||||
|
REWARD_THRESHOLD: float
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.goals = []
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def sample_goals(self, scale: float) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def reward(self, obs: np.ndarray) -> float:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def termination(self, obs: np.ndarray) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def create_maze() -> List[List[MazeCell]]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SingleGoalSparseEMaze(MazeTask):
|
||||||
|
REWARD_THRESHOLD: float = 0.9
|
||||||
|
|
||||||
|
def sample_goals(self, scale: float) -> None:
|
||||||
|
goal = MazeGoal(np.array([0.0, 2.0 * scale]))
|
||||||
|
self.goals = [goal]
|
||||||
|
|
||||||
|
def reward(self, obs: np.ndarray) -> float:
|
||||||
|
if self.goals[0].neighbor(obs):
|
||||||
|
return 1.0
|
||||||
|
else:
|
||||||
|
return -0.0001
|
||||||
|
|
||||||
|
def termination(self, obs: np.ndarray) -> bool:
|
||||||
|
return self.goals[0].neighbor(obs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_maze() -> List[List[MazeCell]]:
|
||||||
|
E, B, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.ROBOT
|
||||||
|
return [
|
||||||
|
[B, B, B, B, B],
|
||||||
|
[B, R, E, E, B],
|
||||||
|
[B, B, B, E, B],
|
||||||
|
[B, E, E, E, B],
|
||||||
|
[B, B, B, B, B],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SingleGoalDenseEMaze(SingleGoalSparseEMaze):
|
||||||
|
REWARD_THRESHOLD: float = 1000.0
|
||||||
|
|
||||||
|
def reward(self, obs: np.ndarray) -> float:
|
||||||
|
return -self.goals[0].euc_dist(obs)
|
||||||
|
|
||||||
|
|
||||||
|
class SingleGoalSparsePush(SingleGoalSparseEMaze):
|
||||||
|
def sample_goals(self, scale: float) -> None:
|
||||||
|
goal = MazeGoal(np.array([0.0, 2.375 * scale]))
|
||||||
|
self.goals = [goal]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_maze() -> List[List[MazeCell]]:
|
||||||
|
E, B, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.ROBOT
|
||||||
|
return [
|
||||||
|
[B, B, B, B, B],
|
||||||
|
[B, E, R, B, B],
|
||||||
|
[B, E, MazeCell.XY, E, B],
|
||||||
|
[B, B, E, B, B],
|
||||||
|
[B, B, B, B, B],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SingleGoalDensePush(SingleGoalSparsePush):
|
||||||
|
REWARD_THRESHOLD: float = 1000.0
|
||||||
|
|
||||||
|
def reward(self, obs: np.ndarray) -> float:
|
||||||
|
return -self.goals[0].euc_dist(obs)
|
||||||
|
|
||||||
|
|
||||||
|
class SingleGoalSparseFall(SingleGoalSparseEMaze):
|
||||||
|
def sample_goals(self, scale: float) -> None:
|
||||||
|
goal = MazeGoal(np.array([0.0, 3.375 * scale, 4.5]))
|
||||||
|
self.goals = [goal]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_maze() -> List[List[MazeCell]]:
|
||||||
|
E, B, C, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.CHASM, MazeCell.ROBOT
|
||||||
|
return [
|
||||||
|
[B, B, B, B],
|
||||||
|
[B, R, E, B],
|
||||||
|
[B, E, MazeCell.YZ, B],
|
||||||
|
[B, C, C, B],
|
||||||
|
[B, E, E, B],
|
||||||
|
[B, B, B, B],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SingleGoalDenseFall(SingleGoalSparseFall):
|
||||||
|
REWARD_THRESHOLD: float = 1000.0
|
||||||
|
|
||||||
|
def reward(self, obs: np.ndarray) -> float:
|
||||||
|
return -self.goals[0].euc_dist(obs)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskRegistry:
|
||||||
|
REGISTRY: Dict[str, List[Type[MazeTask]]] = {
|
||||||
|
"Maze": [SingleGoalDenseEMaze, SingleGoalSparseEMaze],
|
||||||
|
"Push": [SingleGoalDensePush, SingleGoalSparsePush],
|
||||||
|
"Fall": [SingleGoalDenseFall, SingleGoalSparseFall],
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user