Implement a goal base reward function
This commit is contained in:
parent
7c20df20d7
commit
38f87fbb2d
@ -3,11 +3,18 @@ import gym
|
||||
MAZE_IDS = ["Maze", "Push", "Fall", "Block", "BlockMaze"]
|
||||
|
||||
|
||||
def _get_kwargs(maze_id: str) -> tuple:
|
||||
return {
|
||||
"observe_blocks": maze_id in ["Block", "BlockMaze"],
|
||||
"pin_spin_near_agent": maze_id in ["Block", "BlockMaze"],
|
||||
}
|
||||
|
||||
|
||||
for maze_id in MAZE_IDS:
|
||||
gym.envs.register(
|
||||
id="AntMaze{}-v0".format(maze_id),
|
||||
entry_point="mujoco_maze.ant_maze_env:AntMazeEnv",
|
||||
kwargs=dict(maze_id=maze_id, manual_collision=True),
|
||||
kwargs=dict(maze_id=maze_id, maze_size_scaling=8, **_get_kwargs(maze_id)),
|
||||
max_episode_steps=1000,
|
||||
reward_threshold=-1000,
|
||||
)
|
||||
@ -16,7 +23,12 @@ for maze_id in MAZE_IDS:
|
||||
gym.envs.register(
|
||||
id="PointMaze{}-v0".format(maze_id),
|
||||
entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
|
||||
kwargs=dict(maze_id=maze_id, manual_collision=True),
|
||||
kwargs=dict(
|
||||
maze_id=maze_id,
|
||||
maze_size_scaling=4,
|
||||
manual_collision=True,
|
||||
**_get_kwargs(maze_id),
|
||||
),
|
||||
max_episode_steps=1000,
|
||||
reward_threshold=-1000,
|
||||
)
|
||||
|
@ -36,4 +36,3 @@ class AgentModel(ABC, MujocoEnv, EzPickle):
|
||||
@abstractmethod
|
||||
def get_ori(self) -> float:
|
||||
pass
|
||||
|
||||
|
@ -126,7 +126,7 @@ class AntEnv(AgentModel):
|
||||
def get_ori(self):
|
||||
ori = [0, 1, 0, 0]
|
||||
ori_ind = self.ORI_IND
|
||||
rot = self.sim.data.qpos[ori_ind: ori_ind + 4] # take the quaternion
|
||||
rot = self.sim.data.qpos[ori_ind : ori_ind + 4] # take the quaternion
|
||||
ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3] # project onto x-y plane
|
||||
ori = math.atan2(ori[1], ori[0])
|
||||
return ori
|
||||
|
@ -22,7 +22,7 @@ import math
|
||||
import numpy as np
|
||||
import gym
|
||||
|
||||
from typing import Type
|
||||
from typing import Callable, Type, Union
|
||||
|
||||
from mujoco_maze.agent_model import AgentModel
|
||||
from mujoco_maze import maze_env_utils
|
||||
@ -49,6 +49,8 @@ class MazeEnv(gym.Env):
|
||||
put_spin_near_agent=False,
|
||||
top_down_view=False,
|
||||
manual_collision=False,
|
||||
dense_reward=True,
|
||||
goal_sampler: Union[str, np.ndarray, Callable[[], np.ndarray]] = "default",
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
@ -162,7 +164,7 @@ class MazeEnv(gym.Env):
|
||||
)
|
||||
elif maze_env_utils.can_move(struct): # Movable block.
|
||||
# The "falling" blocks are shrunk slightly and increased in mass to
|
||||
# ensure that it can fall easily through a gap in the platform blocks.
|
||||
# ensure it can fall easily through a gap in the platform blocks.
|
||||
name = "movable_%d_%d" % (i, j)
|
||||
self.movable_blocks.append((name, struct))
|
||||
falling = maze_env_utils.can_move_z(struct)
|
||||
@ -265,6 +267,29 @@ class MazeEnv(gym.Env):
|
||||
tree.write(file_path)
|
||||
self.wrapped_env = self.MODEL_CLASS(*args, file_path=file_path, **kwargs)
|
||||
|
||||
# Set reward function
|
||||
self._reward_fn = _reward_fn(maze_id, dense_reward)
|
||||
|
||||
# Set goal sampler
|
||||
if isinstance(goal_sampler, str):
|
||||
if goal_sampler == "random":
|
||||
self._goal_sampler = lambda: np.random.uniform((-4, -4), (20, 20))
|
||||
elif goal_sampler == "default":
|
||||
default_goal = _default_goal(maze_id)
|
||||
self._goal_sampler = lambda: default_goal
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown goal_sampler: {goal_sampler}")
|
||||
elif isinstance(goal_sampler, np.ndarray):
|
||||
self._goal_sampler = lambda: goal_sampler
|
||||
elif callable(goal_sampler):
|
||||
self._goal_sampler = goal_sampler
|
||||
else:
|
||||
raise ValueError(f"Invalid goal_sampler: {goal_sampler}")
|
||||
self.goal = self._goal_sampler()
|
||||
|
||||
# Set goal function
|
||||
self._goal_fn = _goal_fn(maze_id)
|
||||
|
||||
def get_ori(self):
|
||||
return self.wrapped_env.get_ori()
|
||||
|
||||
@ -472,6 +497,8 @@ class MazeEnv(gym.Env):
|
||||
def reset(self):
|
||||
self.t = 0
|
||||
self.wrapped_env.reset()
|
||||
# Sample a new goal
|
||||
self.goal = self._goal_sampler()
|
||||
if len(self._init_positions) > 1:
|
||||
xy = np.random.choice(self._init_positions)
|
||||
self.wrapped_env.set_xy(xy)
|
||||
@ -529,15 +556,57 @@ class MazeEnv(gym.Env):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_in_goal(self, pos):
|
||||
(np.linalg.norm(obs[:3] - goal) <= 0.6)
|
||||
|
||||
def step(self, action):
|
||||
self.t += 1
|
||||
if self._manual_collision:
|
||||
old_pos = self.wrapped_env.get_xy()
|
||||
inner_next_obs, inner_reward, done, info = self.wrapped_env.step(action)
|
||||
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
|
||||
new_pos = self.wrapped_env.get_xy()
|
||||
if self._is_in_collision(new_pos):
|
||||
self.wrapped_env.set_xy(old_pos)
|
||||
else:
|
||||
inner_next_obs, inner_reward, done, info = self.wrapped_env.step(action)
|
||||
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
|
||||
next_obs = self._get_obs()
|
||||
return next_obs, inner_reward, False, info
|
||||
outer_reward = self._reward_fn(next_obs, self.goal)
|
||||
done = self._goal_fn(next_obs, self.goal)
|
||||
return next_obs, inner_reward + outer_reward, done, info
|
||||
|
||||
|
||||
def _goal_fn(maze_id: str) -> callable:
|
||||
if maze_id in ["Maze", "Push"]:
|
||||
return lambda obs, goal: np.linalg.norm(obs[:2] - goal) <= 0.6
|
||||
elif maze_id == "Fall":
|
||||
return lambda obs, goal: np.linalg.norm(obs[:3] - goal) <= 0.6
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown maze id: {maze_id}")
|
||||
|
||||
|
||||
def _reward_fn(maze_id: str, dense: str) -> callable:
|
||||
if dense:
|
||||
if maze_id in ["Maze", "Push"]:
|
||||
return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5
|
||||
elif maze_id == "Fall":
|
||||
return lambda obs, goal: -np.sum(np.square(obs[:3] - goal)) ** 0.5
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown maze id: {maze_id}")
|
||||
else:
|
||||
if maze_id in ["Maze", "Push"]:
|
||||
return lambda obs, goal: (np.linalg.norm(obs[:2] - goal) <= 0.6) * 1.0
|
||||
elif maze_id == "Fall":
|
||||
return lambda obs, goal: (np.linalg.norm(obs[:3] - goal) <= 0.6) * 1.0
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown maze id: {maze_id}")
|
||||
|
||||
|
||||
def _default_goal(maze_id: str) -> np.ndarray:
|
||||
if maze_id == "Maze":
|
||||
return np.array([0.0, 8.0])
|
||||
elif maze_id == "Push":
|
||||
return np.array([0.0, 19.0])
|
||||
elif maze_id == "Fall":
|
||||
return np.array([0.0, 27.0, 4.5])
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown maze id: {maze_id}")
|
||||
|
Loading…
Reference in New Issue
Block a user