check time validity before pos validity
Commit b9c2348855 · Parent d384e6e764
@@ -164,95 +164,101 @@ class BlackBoxWrapper(gym.ObservationWrapper):
     def step(self, action: np.ndarray):
         """ This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
+        time_valid = self.env.check_time_validity(action)
 
-        if self.plan_counts == 0:
-            self.tau_first_prediction = action[0]
+        if time_valid:
 
-        ## tricky part, only use weights basis
-        # basis_weights = action.reshape(7, -1)
-        # goal_weights = np.zeros((7, 1))
-        # action = np.concatenate((basis_weights, goal_weights), axis=1).flatten()
+            if self.plan_counts == 0:
+                self.tau_first_prediction = action[0]
 
-        # TODO remove this part, right now only needed for beer pong
-        # mp_params, env_spec_params, proceed = self.env.episode_callback(action, self.traj_gen)
-        position, velocity = self.get_trajectory(action)
-        traj_is_valid = self.env.episode_callback(action, position, velocity)
+            ## tricky part, only use weights basis
+            # basis_weights = action.reshape(7, -1)
+            # goal_weights = np.zeros((7, 1))
+            # action = np.concatenate((basis_weights, goal_weights), axis=1).flatten()
 
-        trajectory_length = len(position)
-        rewards = np.zeros(shape=(trajectory_length,))
-        if self.verbose >= 2:
-            actions = np.zeros(shape=(trajectory_length,) + self.env.action_space.shape)
-            observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape,
-                                    dtype=self.env.observation_space.dtype)
+            # TODO remove this part, right now only needed for beer pong
+            # mp_params, env_spec_params, proceed = self.env.episode_callback(action, self.traj_gen)
+            position, velocity = self.get_trajectory(action)
+            traj_is_valid = self.env.episode_callback(action, position, velocity)
 
-        infos = dict()
-        done = False
+            trajectory_length = len(position)
+            rewards = np.zeros(shape=(trajectory_length,))
+            if self.verbose >= 2:
+                actions = np.zeros(shape=(trajectory_length,) + self.env.action_space.shape)
+                observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape,
+                                        dtype=self.env.observation_space.dtype)
 
-        if self.verbose >= 2:
-            desired_pos_traj = []
-            desired_vel_traj = []
-            pos_traj = []
-            vel_traj = []
+            infos = dict()
+            done = False
 
-        if traj_is_valid:
-            self.plan_counts += 1
-            for t, (pos, vel) in enumerate(zip(position, velocity)):
-                step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
-                c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
-                obs, c_reward, done, info = self.env.step(c_action)
-                rewards[t] = c_reward
+            if self.verbose >= 2:
+                desired_pos_traj = []
+                desired_vel_traj = []
+                pos_traj = []
+                vel_traj = []
 
-                if self.verbose >= 2:
-                    actions[t, :] = c_action
-                    observations[t, :] = obs
+            if traj_is_valid:
+                self.plan_counts += 1
+                for t, (pos, vel) in enumerate(zip(position, velocity)):
+                    step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
+                    c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
+                    obs, c_reward, done, info = self.env.step(c_action)
+                    rewards[t] = c_reward
 
-                for k, v in info.items():
-                    elems = infos.get(k, [None] * trajectory_length)
-                    elems[t] = v
-                    infos[k] = elems
+                    if self.verbose >= 2:
+                        actions[t, :] = c_action
+                        observations[t, :] = obs
 
-                if self.verbose >= 2:
-                    desired_pos_traj.append(pos)
-                    desired_vel_traj.append(vel)
-                    pos_traj.append(self.current_pos)
-                    vel_traj.append(self.current_vel)
+                    for k, v in info.items():
+                        elems = infos.get(k, [None] * trajectory_length)
+                        elems[t] = v
+                        infos[k] = elems
 
-                if self.render_kwargs:
-                    self.env.render(**self.render_kwargs)
+                    if self.verbose >= 2:
+                        desired_pos_traj.append(pos)
+                        desired_vel_traj.append(vel)
+                        pos_traj.append(self.current_pos)
+                        vel_traj.append(self.current_vel)
 
-                if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
-                                                    t + 1 + self.current_traj_steps):
+                    if self.render_kwargs:
+                        self.env.render(**self.render_kwargs)
 
-                    # if self.max_planning_times is not None and self.plan_counts >= self.max_planning_times:
-                    #     continue
+                    if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
+                                                        t + 1 + self.current_traj_steps):
 
-                    self.condition_pos = pos if self.desired_conditioning else self.current_pos
-                    self.condition_vel = vel if self.desired_conditioning else self.current_vel
+                        # if self.max_planning_times is not None and self.plan_counts >= self.max_planning_times:
+                        #     continue
 
-                    break
+                        self.condition_pos = pos if self.desired_conditioning else self.current_pos
+                        self.condition_vel = vel if self.desired_conditioning else self.current_vel
 
-            infos.update({k: v[:t+1] for k, v in infos.items()})
-            self.current_traj_steps += t + 1
+                        break
 
-            if self.verbose >= 2:
-                infos['desired_pos'] = position[:t+1]
-                infos['desired_vel'] = velocity[:t+1]
-                infos['current_pos'] = self.current_pos
-                infos['current_vel'] = self.current_vel
-                infos['step_actions'] = actions[:t + 1]
-                infos['step_observations'] = observations[:t + 1]
-                infos['step_rewards'] = rewards[:t + 1]
-                infos['desired_pos_traj'] = np.array(desired_pos_traj)
-                infos['desired_vel_traj'] = np.array(desired_vel_traj)
-                infos['pos_traj'] = np.array(pos_traj)
-                infos['vel_traj'] = np.array(vel_traj)
+                infos.update({k: v[:t+1] for k, v in infos.items()})
+                self.current_traj_steps += t + 1
 
-            infos['trajectory_length'] = t + 1
-            trajectory_return = self.reward_aggregation(rewards[:t + 1])
-            return self.observation(obs), trajectory_return, done, infos
-        else:
-            obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity)
-            return self.observation(obs), trajectory_return, done, infos
+                if self.verbose >= 2:
+                    infos['desired_pos'] = position[:t+1]
+                    infos['desired_vel'] = velocity[:t+1]
+                    infos['current_pos'] = self.current_pos
+                    infos['current_vel'] = self.current_vel
+                    infos['step_actions'] = actions[:t + 1]
+                    infos['step_observations'] = observations[:t + 1]
+                    infos['step_rewards'] = rewards[:t + 1]
+                    infos['desired_pos_traj'] = np.array(desired_pos_traj)
+                    infos['desired_vel_traj'] = np.array(desired_vel_traj)
+                    infos['pos_traj'] = np.array(pos_traj)
+                    infos['vel_traj'] = np.array(vel_traj)
 
+                infos['trajectory_length'] = t + 1
+                trajectory_return = self.reward_aggregation(rewards[:t + 1])
+                return self.observation(obs), trajectory_return, done, infos
+            else:
+                obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity)
+                return self.observation(obs), trajectory_return, done, infos
+        else:
+            obs, trajectory_return, done, infos = self.env.time_invalid_traj_callback(action)
+            return self.observation(obs), trajectory_return, done, infos
 
     def render(self, **kwargs):
         """Only set render options here, such that they can be used during the rollout.
         This only needs to be called once"""
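
The net effect of the hunk above is an ordering change: the timing parameters of the MP action (tau and delay) are now validated before any trajectory is generated, and the position-bound check via episode_callback only runs for actions whose timing is valid. A minimal sketch of that control flow, using the callback names from the diff; rollout() is a hypothetical stand-in for the per-step tracking loop and not part of the actual wrapper:

import numpy as np

def step_control_flow(wrapper, action: np.ndarray):
    # Sketch only: mirrors the order of checks introduced by this commit.
    if not wrapper.env.check_time_validity(action):        # cheap check on action[0]/action[1] (tau, delay)
        obs, ret, done, infos = wrapper.env.time_invalid_traj_callback(action)
        return wrapper.observation(obs), ret, done, infos

    position, velocity = wrapper.get_trajectory(action)    # only generated for time-valid actions
    traj_is_valid = wrapper.env.episode_callback(action, position, velocity)
    if not traj_is_valid:                                   # e.g. joint position bounds violated
        obs, ret, done, infos = wrapper.env.invalid_traj_callback(action, position, velocity)
        return wrapper.observation(obs), ret, done, infos

    return wrapper.rollout(position, velocity)              # hypothetical helper for the inner loop
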
@@ -560,7 +560,7 @@ for _v in _versions:
     _name = _v.split("-")
     _env_id = f'{_name[0]}ProDMP-{_name[1]}'
     kwargs_dict_tt_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
-    kwargs_dict_tt_prodmp['wrappers'].append(mujoco.box_pushing.MPWrapper)
+    kwargs_dict_tt_prodmp['wrappers'].append(mujoco.table_tennis.MPWrapper)
    kwargs_dict_tt_prodmp['name'] = _v
     kwargs_dict_tt_prodmp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0])
     kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])

@@ -28,12 +28,31 @@ class MPWrapper(RawInterfaceWrapper):
     def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
         return self.data.qvel[:7].copy()
 
+    def check_time_validity(self, action):
+        return action[0] <= tau_bound[1] and action[0] >= tau_bound[0] \
+               and action[1] <= delay_bound[1] and action[1] >= delay_bound[0]
+
+    def time_invalid_traj_callback(self, action) \
+        -> Tuple[np.ndarray, float, bool, dict]:
+        tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
+        delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
+        invalid_penalty = tau_invalid_penalty + delay_invalid_penalty
+        obs = np.concatenate([self.get_obs(), np.array([0])])
+        return obs, -invalid_penalty, True, {
+        "hit_ball": [False],
+        "ball_returned_success": [False],
+        "land_dist_error": [10.],
+        "is_success": [False],
+        'trajectory_length': 1,
+        "num_steps": [1]
+        }
+
     def episode_callback(self, action, pos_traj, vel_traj):
         time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
                      or action[1] > delay_bound[1] or action[1] < delay_bound[0]
         if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
-            return False
-        return True
+            return True
+        return False
 
     def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \
             -> Tuple[np.ndarray, float, bool, dict]:
@@ -43,7 +62,8 @@ class MPWrapper(RawInterfaceWrapper):
         violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
         invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
                           violate_high_bound_error + violate_low_bound_error
-        return self.get_obs(), -invalid_penalty, True, {
+        obs = np.concatenate([self.get_obs(), np.array([0])])
+        return obs, -invalid_penalty, True, {
         "hit_ball": [False],
         "ball_returned_success": [False],
         "land_dist_error": [10.],
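
For reference, the new MPWrapper helpers above only inspect the first two action entries (the predicted tau and delay) and charge a penalty proportional to how far each one leaves its bound. A self-contained sketch of that check and penalty, with placeholder bound values (the real tau_bound / delay_bound come from the table tennis environment, not from here):

import numpy as np

# Placeholder bounds for illustration only.
tau_bound = [0.5, 1.5]
delay_bound = [0.05, 0.15]

def check_time_validity(action: np.ndarray) -> bool:
    # action[0] = predicted trajectory duration (tau), action[1] = predicted delay
    return tau_bound[0] <= action[0] <= tau_bound[1] and \
           delay_bound[0] <= action[1] <= delay_bound[1]

def time_invalid_penalty(action: np.ndarray) -> float:
    # Penalty grows linearly with the distance by which each bound is violated.
    tau_violation = max(0.0, action[0] - tau_bound[1]) + max(0.0, tau_bound[0] - action[0])
    delay_violation = max(0.0, action[1] - delay_bound[1]) + max(0.0, delay_bound[0] - action[1])
    return 3 * (tau_violation + delay_violation)

action = np.array([2.0, 0.1, 0.0])    # tau too large, delay fine
assert not check_time_validity(action)
print(time_invalid_penalty(action))   # 3 * 0.5 = 1.5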