Implement a goal base reward function

This commit is contained in:
kngwyu 2020-05-29 23:54:03 +09:00
parent 7c20df20d7
commit 38f87fbb2d
4 changed files with 89 additions and 9 deletions

View File

@ -3,11 +3,18 @@ import gym
MAZE_IDS = ["Maze", "Push", "Fall", "Block", "BlockMaze"]
def _get_kwargs(maze_id: str) -> tuple:
return {
"observe_blocks": maze_id in ["Block", "BlockMaze"],
"pin_spin_near_agent": maze_id in ["Block", "BlockMaze"],
}
for maze_id in MAZE_IDS:
gym.envs.register(
id="AntMaze{}-v0".format(maze_id),
entry_point="mujoco_maze.ant_maze_env:AntMazeEnv",
kwargs=dict(maze_id=maze_id, manual_collision=True),
kwargs=dict(maze_id=maze_id, maze_size_scaling=8, **_get_kwargs(maze_id)),
max_episode_steps=1000,
reward_threshold=-1000,
)
@ -16,7 +23,12 @@ for maze_id in MAZE_IDS:
gym.envs.register(
id="PointMaze{}-v0".format(maze_id),
entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
kwargs=dict(maze_id=maze_id, manual_collision=True),
kwargs=dict(
maze_id=maze_id,
maze_size_scaling=4,
manual_collision=True,
**_get_kwargs(maze_id),
),
max_episode_steps=1000,
reward_threshold=-1000,
)

View File

@ -36,4 +36,3 @@ class AgentModel(ABC, MujocoEnv, EzPickle):
@abstractmethod
def get_ori(self) -> float:
pass

View File

@ -126,7 +126,7 @@ class AntEnv(AgentModel):
def get_ori(self):
ori = [0, 1, 0, 0]
ori_ind = self.ORI_IND
rot = self.sim.data.qpos[ori_ind: ori_ind + 4] # take the quaternion
rot = self.sim.data.qpos[ori_ind : ori_ind + 4] # take the quaternion
ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3] # project onto x-y plane
ori = math.atan2(ori[1], ori[0])
return ori

View File

@ -22,7 +22,7 @@ import math
import numpy as np
import gym
from typing import Type
from typing import Callable, Type, Union
from mujoco_maze.agent_model import AgentModel
from mujoco_maze import maze_env_utils
@ -49,6 +49,8 @@ class MazeEnv(gym.Env):
put_spin_near_agent=False,
top_down_view=False,
manual_collision=False,
dense_reward=True,
goal_sampler: Union[str, np.ndarray, Callable[[], np.ndarray]] = "default",
*args,
**kwargs,
):
@ -162,7 +164,7 @@ class MazeEnv(gym.Env):
)
elif maze_env_utils.can_move(struct): # Movable block.
# The "falling" blocks are shrunk slightly and increased in mass to
# ensure that it can fall easily through a gap in the platform blocks.
# ensure it can fall easily through a gap in the platform blocks.
name = "movable_%d_%d" % (i, j)
self.movable_blocks.append((name, struct))
falling = maze_env_utils.can_move_z(struct)
@ -265,6 +267,29 @@ class MazeEnv(gym.Env):
tree.write(file_path)
self.wrapped_env = self.MODEL_CLASS(*args, file_path=file_path, **kwargs)
# Set reward function
self._reward_fn = _reward_fn(maze_id, dense_reward)
# Set goal sampler
if isinstance(goal_sampler, str):
if goal_sampler == "random":
self._goal_sampler = lambda: np.random.uniform((-4, -4), (20, 20))
elif goal_sampler == "default":
default_goal = _default_goal(maze_id)
self._goal_sampler = lambda: default_goal
else:
raise NotImplementedError(f"Unknown goal_sampler: {goal_sampler}")
elif isinstance(goal_sampler, np.ndarray):
self._goal_sampler = lambda: goal_sampler
elif callable(goal_sampler):
self._goal_sampler = goal_sampler
else:
raise ValueError(f"Invalid goal_sampler: {goal_sampler}")
self.goal = self._goal_sampler()
# Set goal function
self._goal_fn = _goal_fn(maze_id)
def get_ori(self):
return self.wrapped_env.get_ori()
@ -472,6 +497,8 @@ class MazeEnv(gym.Env):
def reset(self):
self.t = 0
self.wrapped_env.reset()
# Sample a new goal
self.goal = self._goal_sampler()
if len(self._init_positions) > 1:
xy = np.random.choice(self._init_positions)
self.wrapped_env.set_xy(xy)
@ -529,15 +556,57 @@ class MazeEnv(gym.Env):
return True
return False
def _is_in_goal(self, pos):
(np.linalg.norm(obs[:3] - goal) <= 0.6)
def step(self, action):
self.t += 1
if self._manual_collision:
old_pos = self.wrapped_env.get_xy()
inner_next_obs, inner_reward, done, info = self.wrapped_env.step(action)
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
new_pos = self.wrapped_env.get_xy()
if self._is_in_collision(new_pos):
self.wrapped_env.set_xy(old_pos)
else:
inner_next_obs, inner_reward, done, info = self.wrapped_env.step(action)
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
next_obs = self._get_obs()
return next_obs, inner_reward, False, info
outer_reward = self._reward_fn(next_obs, self.goal)
done = self._goal_fn(next_obs, self.goal)
return next_obs, inner_reward + outer_reward, done, info
def _goal_fn(maze_id: str) -> callable:
if maze_id in ["Maze", "Push"]:
return lambda obs, goal: np.linalg.norm(obs[:2] - goal) <= 0.6
elif maze_id == "Fall":
return lambda obs, goal: np.linalg.norm(obs[:3] - goal) <= 0.6
else:
raise NotImplementedError(f"Unknown maze id: {maze_id}")
def _reward_fn(maze_id: str, dense: str) -> callable:
if dense:
if maze_id in ["Maze", "Push"]:
return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5
elif maze_id == "Fall":
return lambda obs, goal: -np.sum(np.square(obs[:3] - goal)) ** 0.5
else:
raise NotImplementedError(f"Unknown maze id: {maze_id}")
else:
if maze_id in ["Maze", "Push"]:
return lambda obs, goal: (np.linalg.norm(obs[:2] - goal) <= 0.6) * 1.0
elif maze_id == "Fall":
return lambda obs, goal: (np.linalg.norm(obs[:3] - goal) <= 0.6) * 1.0
else:
raise NotImplementedError(f"Unknown maze id: {maze_id}")
def _default_goal(maze_id: str) -> np.ndarray:
if maze_id == "Maze":
return np.array([0.0, 8.0])
elif maze_id == "Push":
return np.array([0.0, 19.0])
elif maze_id == "Fall":
return np.array([0.0, 27.0, 4.5])
else:
raise NotImplementedError(f"Unknown maze id: {maze_id}")