Implement a goal-based reward function
Commit 38f87fbb2d (parent 7c20df20d7)
@@ -3,11 +3,18 @@ import gym
 MAZE_IDS = ["Maze", "Push", "Fall", "Block", "BlockMaze"]
 
 
+def _get_kwargs(maze_id: str) -> dict:
+    return {
+        "observe_blocks": maze_id in ["Block", "BlockMaze"],
+        "put_spin_near_agent": maze_id in ["Block", "BlockMaze"],
+    }
+
+
 for maze_id in MAZE_IDS:
     gym.envs.register(
         id="AntMaze{}-v0".format(maze_id),
         entry_point="mujoco_maze.ant_maze_env:AntMazeEnv",
-        kwargs=dict(maze_id=maze_id, manual_collision=True),
+        kwargs=dict(maze_id=maze_id, maze_size_scaling=8, **_get_kwargs(maze_id)),
         max_episode_steps=1000,
         reward_threshold=-1000,
     )
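For reference, the new helper only switches on the block-related options for the Block-family mazes; a quick sketch of what it returns (assuming the dict-returning version above):

    >>> _get_kwargs("BlockMaze")
    {'observe_blocks': True, 'put_spin_near_agent': True}
    >>> _get_kwargs("Push")
    {'observe_blocks': False, 'put_spin_near_agent': False}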
@@ -16,7 +23,12 @@ for maze_id in MAZE_IDS:
     gym.envs.register(
         id="PointMaze{}-v0".format(maze_id),
         entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
-        kwargs=dict(maze_id=maze_id, manual_collision=True),
+        kwargs=dict(
+            maze_id=maze_id,
+            maze_size_scaling=4,
+            manual_collision=True,
+            **_get_kwargs(maze_id),
+        ),
         max_episode_steps=1000,
         reward_threshold=-1000,
     )
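Once registered, the environments should be constructible through gym in the usual way. A minimal sketch, assuming the registrations above run at package import time:

    import gym
    import mujoco_maze  # noqa: F401 -- assumed to trigger the registrations above

    env = gym.make("PointMazeMaze-v0")  # "PointMaze{}-v0".format("Maze")
    obs = env.reset()
    obs, reward, done, info = env.step(env.action_space.sample())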
@@ -36,4 +36,3 @@ class AgentModel(ABC, MujocoEnv, EzPickle):
     @abstractmethod
     def get_ori(self) -> float:
         pass
-
@@ -126,7 +126,7 @@ class AntEnv(AgentModel):
     def get_ori(self):
         ori = [0, 1, 0, 0]
         ori_ind = self.ORI_IND
-        rot = self.sim.data.qpos[ori_ind: ori_ind + 4]  # take the quaternion
+        rot = self.sim.data.qpos[ori_ind : ori_ind + 4]  # take the quaternion
         ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3]  # project onto x-y plane
         ori = math.atan2(ori[1], ori[0])
         return ori
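For intuition, get_ori rotates the unit x-axis (written as the pure quaternion [0, 1, 0, 0]) by the torso quaternion from qpos and reads the heading angle off the x-y plane. A self-contained sketch, with local stand-ins for the q_mult/q_inv helpers (the real ones live in maze_env_utils):

    import math
    import numpy as np

    def q_mult(a, b):
        # Hamilton product in (w, x, y, z) convention
        w1, x1, y1, z1 = a
        w2, x2, y2, z2 = b
        return np.array([
            w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
            w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
            w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
            w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
        ])

    def q_inv(a):
        # conjugate, which equals the inverse for unit quaternions
        return np.array([a[0], -a[1], -a[2], -a[3]])

    rot = np.array([math.cos(math.pi / 8), 0, 0, math.sin(math.pi / 8)])  # 45 deg about z
    ori = q_mult(q_mult(rot, [0, 1, 0, 0]), q_inv(rot))[1:3]  # x-y components
    print(math.atan2(ori[1], ori[0]))  # ~0.7854, i.e. pi/4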
@@ -22,7 +22,7 @@ import math
 import numpy as np
 import gym
 
-from typing import Type
+from typing import Callable, Type, Union
 
 from mujoco_maze.agent_model import AgentModel
 from mujoco_maze import maze_env_utils
@@ -49,6 +49,8 @@ class MazeEnv(gym.Env):
         put_spin_near_agent=False,
         top_down_view=False,
         manual_collision=False,
+        dense_reward=True,
+        goal_sampler: Union[str, np.ndarray, Callable[[], np.ndarray]] = "default",
         *args,
         **kwargs,
     ):
@@ -162,7 +164,7 @@ class MazeEnv(gym.Env):
                     )
                 elif maze_env_utils.can_move(struct):  # Movable block.
                     # The "falling" blocks are shrunk slightly and increased in mass to
-                    # ensure that it can fall easily through a gap in the platform blocks.
+                    # ensure they can fall easily through a gap in the platform blocks.
                     name = "movable_%d_%d" % (i, j)
                     self.movable_blocks.append((name, struct))
                     falling = maze_env_utils.can_move_z(struct)
@@ -265,6 +267,29 @@ class MazeEnv(gym.Env):
         tree.write(file_path)
         self.wrapped_env = self.MODEL_CLASS(*args, file_path=file_path, **kwargs)
 
+        # Set reward function
+        self._reward_fn = _reward_fn(maze_id, dense_reward)
+
+        # Set goal sampler
+        if isinstance(goal_sampler, str):
+            if goal_sampler == "random":
+                self._goal_sampler = lambda: np.random.uniform((-4, -4), (20, 20))
+            elif goal_sampler == "default":
+                default_goal = _default_goal(maze_id)
+                self._goal_sampler = lambda: default_goal
+            else:
+                raise NotImplementedError(f"Unknown goal_sampler: {goal_sampler}")
+        elif isinstance(goal_sampler, np.ndarray):
+            self._goal_sampler = lambda: goal_sampler
+        elif callable(goal_sampler):
+            self._goal_sampler = goal_sampler
+        else:
+            raise ValueError(f"Invalid goal_sampler: {goal_sampler}")
+        self.goal = self._goal_sampler()
+
+        # Set goal function
+        self._goal_fn = _goal_fn(maze_id)
+
     def get_ori(self):
         return self.wrapped_env.get_ori()
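goal_sampler accepts three forms: a named strategy ("default" or "random"), a fixed np.ndarray, or a zero-argument callable invoked on every reset. A hedged sketch of the three forms (PointMazeEnv shown for illustration; any MazeEnv subclass takes the same keyword arguments):

    import numpy as np
    from mujoco_maze.point_maze_env import PointMazeEnv

    env_a = PointMazeEnv(maze_id="Maze", goal_sampler="random")              # uniform box
    env_b = PointMazeEnv(maze_id="Maze", goal_sampler=np.array([0.0, 8.0]))  # fixed goal
    env_c = PointMazeEnv(maze_id="Maze", goal_sampler=lambda: np.random.uniform((0, 0), (8, 8)))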
@@ -472,6 +497,8 @@ class MazeEnv(gym.Env):
     def reset(self):
         self.t = 0
         self.wrapped_env.reset()
+        # Sample a new goal
+        self.goal = self._goal_sampler()
         if len(self._init_positions) > 1:
             xy = np.random.choice(self._init_positions)
             self.wrapped_env.set_xy(xy)
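Since reset() re-invokes the sampler, "random" or callable samplers draw a fresh goal each episode while "default" and ndarray samplers keep it fixed; continuing the sketch above:

    for _ in range(3):
        env_c.reset()
        print(env_c.goal)  # a different goal each episode (with high probability)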
@@ -529,15 +556,57 @@
                         return True
         return False
 
+    def _is_in_goal(self, pos):
+        return self._goal_fn(pos, self.goal)
+
     def step(self, action):
         self.t += 1
         if self._manual_collision:
             old_pos = self.wrapped_env.get_xy()
-            inner_next_obs, inner_reward, done, info = self.wrapped_env.step(action)
+            inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
             new_pos = self.wrapped_env.get_xy()
             if self._is_in_collision(new_pos):
                 self.wrapped_env.set_xy(old_pos)
         else:
-            inner_next_obs, inner_reward, done, info = self.wrapped_env.step(action)
+            inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
         next_obs = self._get_obs()
-        return next_obs, inner_reward, False, info
+        outer_reward = self._reward_fn(next_obs, self.goal)
+        done = self._goal_fn(next_obs, self.goal)
+        return next_obs, inner_reward + outer_reward, done, info
+
+
+def _goal_fn(maze_id: str) -> callable:
+    if maze_id in ["Maze", "Push"]:
+        return lambda obs, goal: np.linalg.norm(obs[:2] - goal) <= 0.6
+    elif maze_id == "Fall":
+        return lambda obs, goal: np.linalg.norm(obs[:3] - goal) <= 0.6
+    else:
+        raise NotImplementedError(f"Unknown maze id: {maze_id}")
+
+
+def _reward_fn(maze_id: str, dense: bool) -> callable:
+    if dense:
+        if maze_id in ["Maze", "Push"]:
+            return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5
+        elif maze_id == "Fall":
+            return lambda obs, goal: -np.sum(np.square(obs[:3] - goal)) ** 0.5
+        else:
+            raise NotImplementedError(f"Unknown maze id: {maze_id}")
+    else:
+        if maze_id in ["Maze", "Push"]:
+            return lambda obs, goal: (np.linalg.norm(obs[:2] - goal) <= 0.6) * 1.0
+        elif maze_id == "Fall":
+            return lambda obs, goal: (np.linalg.norm(obs[:3] - goal) <= 0.6) * 1.0
+        else:
+            raise NotImplementedError(f"Unknown maze id: {maze_id}")
+
+
+def _default_goal(maze_id: str) -> np.ndarray:
+    if maze_id == "Maze":
+        return np.array([0.0, 8.0])
+    elif maze_id == "Push":
+        return np.array([0.0, 19.0])
+    elif maze_id == "Fall":
+        return np.array([0.0, 27.0, 4.5])
+    else:
+        raise NotImplementedError(f"Unknown maze id: {maze_id}")
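To make the two reward modes concrete, a worked example under the defaults above, with the "Maze" goal (0, 8) and an agent whose leading observation entries are (0, 6):

    import numpy as np

    goal = np.array([0.0, 8.0])   # _default_goal("Maze")
    obs = np.array([0.0, 6.0])    # x-y entries of the observation

    dense = -np.sum(np.square(obs[:2] - goal)) ** 0.5       # -2.0, negative distance
    sparse = (np.linalg.norm(obs[:2] - goal) <= 0.6) * 1.0  # 0.0 until within 0.6
    done = np.linalg.norm(obs[:2] - goal) <= 0.6            # False; matches _goal_fn("Maze")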