Make some configurations class attributes
This commit is contained in:
parent
bbbe0f38e3
commit
d5cc345080
@ -3,20 +3,13 @@ import gym
|
||||
from mujoco_maze.maze_task import TaskRegistry
|
||||
|
||||
|
||||
def _get_kwargs(maze_id: str) -> tuple:
|
||||
return {
|
||||
"maze_id": maze_id,
|
||||
"observe_blocks": maze_id in ["Block", "BlockMaze"],
|
||||
"put_spin_near_agent": maze_id in ["Block", "BlockMaze"],
|
||||
}
|
||||
|
||||
|
||||
for maze_id in TaskRegistry.keys():
|
||||
for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)):
|
||||
scaling = task_cls.SCALING.ant
|
||||
gym.envs.register(
|
||||
id=f"Ant{maze_id}-v{i}",
|
||||
entry_point="mujoco_maze.ant_maze_env:AntMazeEnv",
|
||||
kwargs=dict(maze_task=task_cls, maze_size_scaling=8.0),
|
||||
kwargs=dict(maze_task=task_cls, maze_size_scaling=scaling),
|
||||
max_episode_steps=1000,
|
||||
reward_threshold=task_cls.REWARD_THRESHOLD,
|
||||
)
|
||||
@ -26,7 +19,7 @@ for maze_id in TaskRegistry.keys():
|
||||
gym.envs.register(
|
||||
id=f"Point{maze_id}-v{i}",
|
||||
entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
|
||||
kwargs=dict(maze_task=task_cls),
|
||||
kwargs=dict(maze_task=task_cls, maze_size_scaling=scaling),
|
||||
max_episode_steps=1000,
|
||||
reward_threshold=task_cls.REWARD_THRESHOLD,
|
||||
)
|
||||
|
@ -41,8 +41,6 @@ class MazeEnv(gym.Env):
|
||||
n_bins: int = 0,
|
||||
sensor_range: float = 3.0,
|
||||
sensor_span: float = 2 * np.pi,
|
||||
observe_blocks: float = False,
|
||||
put_spin_near_agent: float = False,
|
||||
top_down_view: float = False,
|
||||
maze_height: float = 0.5,
|
||||
maze_size_scaling: float = 4.0,
|
||||
@ -61,8 +59,8 @@ class MazeEnv(gym.Env):
|
||||
self._n_bins = n_bins
|
||||
self._sensor_range = sensor_range * size_scaling
|
||||
self._sensor_span = sensor_span
|
||||
self._observe_blocks = observe_blocks
|
||||
self._put_spin_near_agent = put_spin_near_agent
|
||||
self._observe_blocks = self._task.OBSERVE_BLOCKS
|
||||
self._put_spin_near_agent = self._task.PUT_SPIN_NEAR_AGENT
|
||||
self._top_down_view = top_down_view
|
||||
self._collision_coef = 0.1
|
||||
|
||||
|
@ -40,8 +40,16 @@ class MazeGoal:
|
||||
return np.sum(np.square(obs[: self.dim] - self.pos)) ** 0.5
|
||||
|
||||
|
||||
class Scaling(NamedTuple):
|
||||
ant: float
|
||||
point: float
|
||||
|
||||
|
||||
class MazeTask(ABC):
|
||||
REWARD_THRESHOLD: float
|
||||
SCALING: Scaling = Scaling(8.0, 4.0)
|
||||
OBSERVE_BLOCKS: bool = False
|
||||
PUT_SPIN_NEAR_AGENT: bool = False
|
||||
|
||||
def __init__(self, scale: float) -> None:
|
||||
self.scale = scale
|
||||
|
Loading…
Reference in New Issue
Block a user