Implement a goal base reward function
This commit is contained in:
		
							parent
							
								
									7c20df20d7
								
							
						
					
					
						commit
						38f87fbb2d
					
				@ -3,11 +3,18 @@ import gym
 | 
			
		||||
MAZE_IDS = ["Maze", "Push", "Fall", "Block", "BlockMaze"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_kwargs(maze_id: str) -> tuple:
 | 
			
		||||
    return {
 | 
			
		||||
        "observe_blocks": maze_id in ["Block", "BlockMaze"],
 | 
			
		||||
        "pin_spin_near_agent": maze_id in ["Block", "BlockMaze"],
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
for maze_id in MAZE_IDS:
 | 
			
		||||
    gym.envs.register(
 | 
			
		||||
        id="AntMaze{}-v0".format(maze_id),
 | 
			
		||||
        entry_point="mujoco_maze.ant_maze_env:AntMazeEnv",
 | 
			
		||||
        kwargs=dict(maze_id=maze_id, manual_collision=True),
 | 
			
		||||
        kwargs=dict(maze_id=maze_id, maze_size_scaling=8, **_get_kwargs(maze_id)),
 | 
			
		||||
        max_episode_steps=1000,
 | 
			
		||||
        reward_threshold=-1000,
 | 
			
		||||
    )
 | 
			
		||||
@ -16,7 +23,12 @@ for maze_id in MAZE_IDS:
 | 
			
		||||
    gym.envs.register(
 | 
			
		||||
        id="PointMaze{}-v0".format(maze_id),
 | 
			
		||||
        entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
 | 
			
		||||
        kwargs=dict(maze_id=maze_id, manual_collision=True),
 | 
			
		||||
        kwargs=dict(
 | 
			
		||||
            maze_id=maze_id,
 | 
			
		||||
            maze_size_scaling=4,
 | 
			
		||||
            manual_collision=True,
 | 
			
		||||
            **_get_kwargs(maze_id),
 | 
			
		||||
        ),
 | 
			
		||||
        max_episode_steps=1000,
 | 
			
		||||
        reward_threshold=-1000,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -36,4 +36,3 @@ class AgentModel(ABC, MujocoEnv, EzPickle):
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def get_ori(self) -> float:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -22,7 +22,7 @@ import math
 | 
			
		||||
import numpy as np
 | 
			
		||||
import gym
 | 
			
		||||
 | 
			
		||||
from typing import Type
 | 
			
		||||
from typing import Callable, Type, Union
 | 
			
		||||
 | 
			
		||||
from mujoco_maze.agent_model import AgentModel
 | 
			
		||||
from mujoco_maze import maze_env_utils
 | 
			
		||||
@ -49,6 +49,8 @@ class MazeEnv(gym.Env):
 | 
			
		||||
        put_spin_near_agent=False,
 | 
			
		||||
        top_down_view=False,
 | 
			
		||||
        manual_collision=False,
 | 
			
		||||
        dense_reward=True,
 | 
			
		||||
        goal_sampler: Union[str, np.ndarray, Callable[[], np.ndarray]] = "default",
 | 
			
		||||
        *args,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ):
 | 
			
		||||
@ -162,7 +164,7 @@ class MazeEnv(gym.Env):
 | 
			
		||||
                    )
 | 
			
		||||
                elif maze_env_utils.can_move(struct):  # Movable block.
 | 
			
		||||
                    # The "falling" blocks are shrunk slightly and increased in mass to
 | 
			
		||||
                    # ensure that it can fall easily through a gap in the platform blocks.
 | 
			
		||||
                    # ensure it can fall easily through a gap in the platform blocks.
 | 
			
		||||
                    name = "movable_%d_%d" % (i, j)
 | 
			
		||||
                    self.movable_blocks.append((name, struct))
 | 
			
		||||
                    falling = maze_env_utils.can_move_z(struct)
 | 
			
		||||
@ -265,6 +267,29 @@ class MazeEnv(gym.Env):
 | 
			
		||||
        tree.write(file_path)
 | 
			
		||||
        self.wrapped_env = self.MODEL_CLASS(*args, file_path=file_path, **kwargs)
 | 
			
		||||
 | 
			
		||||
        # Set reward function
 | 
			
		||||
        self._reward_fn = _reward_fn(maze_id, dense_reward)
 | 
			
		||||
 | 
			
		||||
        # Set goal sampler
 | 
			
		||||
        if isinstance(goal_sampler, str):
 | 
			
		||||
            if goal_sampler == "random":
 | 
			
		||||
                self._goal_sampler = lambda: np.random.uniform((-4, -4), (20, 20))
 | 
			
		||||
            elif goal_sampler == "default":
 | 
			
		||||
                default_goal = _default_goal(maze_id)
 | 
			
		||||
                self._goal_sampler = lambda: default_goal
 | 
			
		||||
            else:
 | 
			
		||||
                raise NotImplementedError(f"Unknown goal_sampler: {goal_sampler}")
 | 
			
		||||
        elif isinstance(goal_sampler, np.ndarray):
 | 
			
		||||
            self._goal_sampler = lambda: goal_sampler
 | 
			
		||||
        elif callable(goal_sampler):
 | 
			
		||||
            self._goal_sampler = goal_sampler
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError(f"Invalid goal_sampler: {goal_sampler}")
 | 
			
		||||
        self.goal = self._goal_sampler()
 | 
			
		||||
 | 
			
		||||
        # Set goal function
 | 
			
		||||
        self._goal_fn = _goal_fn(maze_id)
 | 
			
		||||
 | 
			
		||||
    def get_ori(self):
 | 
			
		||||
        return self.wrapped_env.get_ori()
 | 
			
		||||
 | 
			
		||||
@ -472,6 +497,8 @@ class MazeEnv(gym.Env):
 | 
			
		||||
    def reset(self):
 | 
			
		||||
        self.t = 0
 | 
			
		||||
        self.wrapped_env.reset()
 | 
			
		||||
        # Sample a new goal
 | 
			
		||||
        self.goal = self._goal_sampler()
 | 
			
		||||
        if len(self._init_positions) > 1:
 | 
			
		||||
            xy = np.random.choice(self._init_positions)
 | 
			
		||||
            self.wrapped_env.set_xy(xy)
 | 
			
		||||
@ -529,15 +556,57 @@ class MazeEnv(gym.Env):
 | 
			
		||||
                        return True
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    def _is_in_goal(self, pos):
 | 
			
		||||
        (np.linalg.norm(obs[:3] - goal) <= 0.6)
 | 
			
		||||
 | 
			
		||||
    def step(self, action):
 | 
			
		||||
        self.t += 1
 | 
			
		||||
        if self._manual_collision:
 | 
			
		||||
            old_pos = self.wrapped_env.get_xy()
 | 
			
		||||
            inner_next_obs, inner_reward, done, info = self.wrapped_env.step(action)
 | 
			
		||||
            inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
 | 
			
		||||
            new_pos = self.wrapped_env.get_xy()
 | 
			
		||||
            if self._is_in_collision(new_pos):
 | 
			
		||||
                self.wrapped_env.set_xy(old_pos)
 | 
			
		||||
        else:
 | 
			
		||||
            inner_next_obs, inner_reward, done, info = self.wrapped_env.step(action)
 | 
			
		||||
            inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
 | 
			
		||||
        next_obs = self._get_obs()
 | 
			
		||||
        return next_obs, inner_reward, False, info
 | 
			
		||||
        outer_reward = self._reward_fn(next_obs, self.goal)
 | 
			
		||||
        done = self._goal_fn(next_obs, self.goal)
 | 
			
		||||
        return next_obs, inner_reward + outer_reward, done, info
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _goal_fn(maze_id: str) -> callable:
 | 
			
		||||
    if maze_id in ["Maze", "Push"]:
 | 
			
		||||
        return lambda obs, goal: np.linalg.norm(obs[:2] - goal) <= 0.6
 | 
			
		||||
    elif maze_id == "Fall":
 | 
			
		||||
        return lambda obs, goal: np.linalg.norm(obs[:3] - goal) <= 0.6
 | 
			
		||||
    else:
 | 
			
		||||
        raise NotImplementedError(f"Unknown maze id: {maze_id}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _reward_fn(maze_id: str, dense: str) -> callable:
 | 
			
		||||
    if dense:
 | 
			
		||||
        if maze_id in ["Maze", "Push"]:
 | 
			
		||||
            return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5
 | 
			
		||||
        elif maze_id == "Fall":
 | 
			
		||||
            return lambda obs, goal: -np.sum(np.square(obs[:3] - goal)) ** 0.5
 | 
			
		||||
        else:
 | 
			
		||||
            raise NotImplementedError(f"Unknown maze id: {maze_id}")
 | 
			
		||||
    else:
 | 
			
		||||
        if maze_id in ["Maze", "Push"]:
 | 
			
		||||
            return lambda obs, goal: (np.linalg.norm(obs[:2] - goal) <= 0.6) * 1.0
 | 
			
		||||
        elif maze_id == "Fall":
 | 
			
		||||
            return lambda obs, goal: (np.linalg.norm(obs[:3] - goal) <= 0.6) * 1.0
 | 
			
		||||
        else:
 | 
			
		||||
            raise NotImplementedError(f"Unknown maze id: {maze_id}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _default_goal(maze_id: str) -> np.ndarray:
 | 
			
		||||
    if maze_id == "Maze":
 | 
			
		||||
        return np.array([0.0, 8.0])
 | 
			
		||||
    elif maze_id == "Push":
 | 
			
		||||
        return np.array([0.0, 19.0])
 | 
			
		||||
    elif maze_id == "Fall":
 | 
			
		||||
        return np.array([0.0, 27.0, 4.5])
 | 
			
		||||
    else:
 | 
			
		||||
        raise NotImplementedError(f"Unknown maze id: {maze_id}")
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user