Make some configurations class attributes

This commit is contained in:
kngwyu 2020-06-30 22:42:22 +09:00
parent bbbe0f38e3
commit d5cc345080
3 changed files with 13 additions and 14 deletions

View File

@ -3,20 +3,13 @@ import gym
from mujoco_maze.maze_task import TaskRegistry
def _get_kwargs(maze_id: str) -> tuple:
return {
"maze_id": maze_id,
"observe_blocks": maze_id in ["Block", "BlockMaze"],
"put_spin_near_agent": maze_id in ["Block", "BlockMaze"],
}
for maze_id in TaskRegistry.keys():
for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)):
scaling = task_cls.SCALING.ant
gym.envs.register(
id=f"Ant{maze_id}-v{i}",
entry_point="mujoco_maze.ant_maze_env:AntMazeEnv",
kwargs=dict(maze_task=task_cls, maze_size_scaling=8.0),
kwargs=dict(maze_task=task_cls, maze_size_scaling=scaling),
max_episode_steps=1000,
reward_threshold=task_cls.REWARD_THRESHOLD,
)
@ -26,7 +19,7 @@ for maze_id in TaskRegistry.keys():
gym.envs.register(
id=f"Point{maze_id}-v{i}",
entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
kwargs=dict(maze_task=task_cls),
kwargs=dict(maze_task=task_cls, maze_size_scaling=scaling),
max_episode_steps=1000,
reward_threshold=task_cls.REWARD_THRESHOLD,
)

View File

@ -41,8 +41,6 @@ class MazeEnv(gym.Env):
n_bins: int = 0,
sensor_range: float = 3.0,
sensor_span: float = 2 * np.pi,
observe_blocks: float = False,
put_spin_near_agent: float = False,
top_down_view: float = False,
maze_height: float = 0.5,
maze_size_scaling: float = 4.0,
@ -61,8 +59,8 @@ class MazeEnv(gym.Env):
self._n_bins = n_bins
self._sensor_range = sensor_range * size_scaling
self._sensor_span = sensor_span
self._observe_blocks = observe_blocks
self._put_spin_near_agent = put_spin_near_agent
self._observe_blocks = self._task.OBSERVE_BLOCKS
self._put_spin_near_agent = self._task.PUT_SPIN_NEAR_AGENT
self._top_down_view = top_down_view
self._collision_coef = 0.1

View File

@ -40,8 +40,16 @@ class MazeGoal:
return np.sum(np.square(obs[: self.dim] - self.pos)) ** 0.5
# Pair of float values, one per agent kind: ``ant`` first, then ``point``.
# Fields can be read by name (``scaling.ant``) or by position.
Scaling = NamedTuple("Scaling", [("ant", float), ("point", float)])
class MazeTask(ABC):
REWARD_THRESHOLD: float
SCALING: Scaling = Scaling(8.0, 4.0)
OBSERVE_BLOCKS: bool = False
PUT_SPIN_NEAR_AGENT: bool = False
def __init__(self, scale: float) -> None:
self.scale = scale