Implement a goal-based reward function
parent 7c20df20d7
commit 38f87fbb2d
mujoco_maze/__init__.py
@@ -3,11 +3,18 @@ import gym
 MAZE_IDS = ["Maze", "Push", "Fall", "Block", "BlockMaze"]
 
 
+def _get_kwargs(maze_id: str) -> dict:
+    return {
+        "observe_blocks": maze_id in ["Block", "BlockMaze"],
+        "put_spin_near_agent": maze_id in ["Block", "BlockMaze"],
+    }
+
+
 for maze_id in MAZE_IDS:
     gym.envs.register(
         id="AntMaze{}-v0".format(maze_id),
         entry_point="mujoco_maze.ant_maze_env:AntMazeEnv",
-        kwargs=dict(maze_id=maze_id, manual_collision=True),
+        kwargs=dict(maze_id=maze_id, maze_size_scaling=8, **_get_kwargs(maze_id)),
         max_episode_steps=1000,
         reward_threshold=-1000,
     )
@@ -16,7 +23,12 @@ for maze_id in MAZE_IDS:
     gym.envs.register(
         id="PointMaze{}-v0".format(maze_id),
         entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
-        kwargs=dict(maze_id=maze_id, manual_collision=True),
+        kwargs=dict(
+            maze_id=maze_id,
+            maze_size_scaling=4,
+            manual_collision=True,
+            **_get_kwargs(maze_id),
+        ),
         max_episode_steps=1000,
         reward_threshold=-1000,
     )
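For orientation, each registered id can be instantiated through the standard gym API. A minimal sketch, assuming MuJoCo and this package are installed; the id is formed by the loop above, e.g. "PointMaze{}-v0".format("Maze"):

import gym

import mujoco_maze  # noqa: F401 (importing runs the registrations above)

env = gym.make("PointMazeMaze-v0")
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())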
mujoco_maze/agent_model.py
@@ -36,4 +36,3 @@ class AgentModel(ABC, MujocoEnv, EzPickle):
     @abstractmethod
     def get_ori(self) -> float:
         pass
-
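For context, get_ori is the abstract hook that concrete agents override to report their orientation. A hypothetical sketch of a subclass; the model file name and the qpos index are illustrative assumptions, not the repository's actual values:

from mujoco_maze.agent_model import AgentModel


class MyAgent(AgentModel):
    FILE = "my_agent.xml"  # hypothetical MuJoCo model file

    def get_ori(self) -> float:
        # Hypothetical state layout: yaw angle stored at qpos index 2.
        return float(self.sim.data.qpos[2])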
mujoco_maze/maze_env.py
@@ -22,7 +22,7 @@ import math
 import numpy as np
 import gym
 
-from typing import Type
+from typing import Callable, Type, Union
 
 from mujoco_maze.agent_model import AgentModel
 from mujoco_maze import maze_env_utils
@@ -49,6 +49,8 @@ class MazeEnv(gym.Env):
         put_spin_near_agent=False,
         top_down_view=False,
         manual_collision=False,
+        dense_reward=True,
+        goal_sampler: Union[str, np.ndarray, Callable[[], np.ndarray]] = "default",
         *args,
         **kwargs,
     ):
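The two new keyword arguments can also be set from the registered entry points. A sketch, assuming the installed gym version forwards extra keyword arguments through the registry to the constructor (recent gym releases do):

import gym

import mujoco_maze  # noqa: F401

# Sparse reward, and a goal drawn uniformly at random on each reset.
env = gym.make("PointMazeMaze-v0", dense_reward=False, goal_sampler="random")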
@@ -162,7 +164,7 @@ class MazeEnv(gym.Env):
             )
         elif maze_env_utils.can_move(struct):  # Movable block.
             # The "falling" blocks are shrunk slightly and increased in mass to
-            # ensure that it can fall easily through a gap in the platform blocks.
+            # ensure they can fall easily through a gap in the platform blocks.
             name = "movable_%d_%d" % (i, j)
             self.movable_blocks.append((name, struct))
             falling = maze_env_utils.can_move_z(struct)
@@ -265,6 +267,29 @@ class MazeEnv(gym.Env):
         tree.write(file_path)
         self.wrapped_env = self.MODEL_CLASS(*args, file_path=file_path, **kwargs)
 
+        # Set reward function
+        self._reward_fn = _reward_fn(maze_id, dense_reward)
+
+        # Set goal sampler
+        if isinstance(goal_sampler, str):
+            if goal_sampler == "random":
+                self._goal_sampler = lambda: np.random.uniform((-4, -4), (20, 20))
+            elif goal_sampler == "default":
+                default_goal = _default_goal(maze_id)
+                self._goal_sampler = lambda: default_goal
+            else:
+                raise NotImplementedError(f"Unknown goal_sampler: {goal_sampler}")
+        elif isinstance(goal_sampler, np.ndarray):
+            self._goal_sampler = lambda: goal_sampler
+        elif callable(goal_sampler):
+            self._goal_sampler = goal_sampler
+        else:
+            raise ValueError(f"Invalid goal_sampler: {goal_sampler}")
+        self.goal = self._goal_sampler()
+
+        # Set goal function
+        self._goal_fn = _goal_fn(maze_id)
+
     def get_ori(self):
         return self.wrapped_env.get_ori()
 
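Besides the "default" and "random" strings, the dispatch above also accepts a fixed np.ndarray or any zero-argument callable. A small sketch of the callable form; the candidate goal coordinates are made up for illustration:

import numpy as np
import gym

import mujoco_maze  # noqa: F401

# Hypothetical candidate goals; any () -> np.ndarray callable works.
CANDIDATES = np.array([[0.0, 8.0], [8.0, 8.0], [8.0, 0.0]])


def sample_goal() -> np.ndarray:
    return CANDIDATES[np.random.randint(len(CANDIDATES))]


env = gym.make("PointMazeMaze-v0", goal_sampler=sample_goal)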
@@ -472,6 +497,8 @@ class MazeEnv(gym.Env):
     def reset(self):
         self.t = 0
         self.wrapped_env.reset()
+        # Sample a new goal
+        self.goal = self._goal_sampler()
         if len(self._init_positions) > 1:
             xy = np.random.choice(self._init_positions)
             self.wrapped_env.set_xy(xy)
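Since reset() now resamples before repositioning, a "random" sampler yields a fresh goal every episode. A quick check, assuming gym.make forwards kwargs as above and that the TimeLimit wrapper is peeled off via unwrapped:

import gym

import mujoco_maze  # noqa: F401

env = gym.make("PointMazeMaze-v0", goal_sampler="random")
for _ in range(3):
    env.reset()
    print(env.unwrapped.goal)  # a freshly sampled 2D goal each episode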
@@ -529,15 +556,57 @@
             return True
         return False
 
+    def _is_in_goal(self, pos):
+        return np.linalg.norm(pos - self.goal[: len(pos)]) <= 0.6
+
     def step(self, action):
         self.t += 1
         if self._manual_collision:
             old_pos = self.wrapped_env.get_xy()
-            inner_next_obs, inner_reward, done, info = self.wrapped_env.step(action)
+            inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
             new_pos = self.wrapped_env.get_xy()
             if self._is_in_collision(new_pos):
                 self.wrapped_env.set_xy(old_pos)
         else:
-            inner_next_obs, inner_reward, done, info = self.wrapped_env.step(action)
+            inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
         next_obs = self._get_obs()
-        return next_obs, inner_reward, False, info
+        outer_reward = self._reward_fn(next_obs, self.goal)
+        done = self._goal_fn(next_obs, self.goal)
+        return next_obs, inner_reward + outer_reward, done, info
+
+
+def _goal_fn(maze_id: str) -> Callable:
+    if maze_id in ["Maze", "Push"]:
+        return lambda obs, goal: np.linalg.norm(obs[:2] - goal) <= 0.6
+    elif maze_id == "Fall":
+        return lambda obs, goal: np.linalg.norm(obs[:3] - goal) <= 0.6
+    else:
+        raise NotImplementedError(f"Unknown maze id: {maze_id}")
+
+
+def _reward_fn(maze_id: str, dense: bool) -> Callable:
+    if dense:
+        if maze_id in ["Maze", "Push"]:
+            return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5
+        elif maze_id == "Fall":
+            return lambda obs, goal: -np.sum(np.square(obs[:3] - goal)) ** 0.5
+        else:
+            raise NotImplementedError(f"Unknown maze id: {maze_id}")
+    else:
+        if maze_id in ["Maze", "Push"]:
+            return lambda obs, goal: (np.linalg.norm(obs[:2] - goal) <= 0.6) * 1.0
+        elif maze_id == "Fall":
+            return lambda obs, goal: (np.linalg.norm(obs[:3] - goal) <= 0.6) * 1.0
+        else:
+            raise NotImplementedError(f"Unknown maze id: {maze_id}")
+
+
+def _default_goal(maze_id: str) -> np.ndarray:
+    if maze_id == "Maze":
+        return np.array([0.0, 8.0])
+    elif maze_id == "Push":
+        return np.array([0.0, 19.0])
+    elif maze_id == "Fall":
+        return np.array([0.0, 27.0, 4.5])
+    else:
+        raise NotImplementedError(f"Unknown maze id: {maze_id}")
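For a concrete feel of the two reward shapes: with the default "Maze" goal (0, 8) and an agent at (0, 5), the dense variant returns the negative Euclidean distance to the goal, while the sparse variant pays 1.0 only inside the 0.6 success radius. A quick check (note that -np.sum(np.square(d)) ** 0.5 is just -np.linalg.norm(d), since unary minus binds looser than **):

import numpy as np

obs = np.array([0.0, 5.0, 0.5])  # xy position (extra dims ignored by the slice)
goal = np.array([0.0, 8.0])      # _default_goal("Maze")

dense = -np.sum(np.square(obs[:2] - goal)) ** 0.5   # == -np.linalg.norm(obs[:2] - goal)
sparse = (np.linalg.norm(obs[:2] - goal) <= 0.6) * 1.0

print(dense)   # -3.0
print(sparse)  # 0.0, farther than the 0.6 success radius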