add contextual obs option to invalid trajectory callback
This commit is contained in:
		
							parent
							
								
									c242c32a41
								
							
						
					
					
						commit
						2735e0bf24
					
				@ -169,8 +169,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
 | 
			
		||||
        infos = dict()
 | 
			
		||||
        done = False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        if traj_is_valid:
 | 
			
		||||
            self.plan_steps += 1
 | 
			
		||||
            for t, (pos, vel) in enumerate(zip(position, velocity)):
 | 
			
		||||
@ -207,18 +205,19 @@ class BlackBoxWrapper(gym.ObservationWrapper):
 | 
			
		||||
            infos.update({k: v[:t+1] for k, v in infos.items()})
 | 
			
		||||
            self.current_traj_steps += t + 1
 | 
			
		||||
 | 
			
		||||
        if self.verbose >= 2:
 | 
			
		||||
            infos['positions'] = position
 | 
			
		||||
            infos['velocities'] = velocity
 | 
			
		||||
            infos['step_actions'] = actions[:t + 1]
 | 
			
		||||
            infos['step_observations'] = observations[:t + 1]
 | 
			
		||||
            infos['step_rewards'] = rewards[:t + 1]
 | 
			
		||||
            if self.verbose >= 2:
 | 
			
		||||
                infos['positions'] = position
 | 
			
		||||
                infos['velocities'] = velocity
 | 
			
		||||
                infos['step_actions'] = actions[:t + 1]
 | 
			
		||||
                infos['step_observations'] = observations[:t + 1]
 | 
			
		||||
                infos['step_rewards'] = rewards[:t + 1]
 | 
			
		||||
 | 
			
		||||
            infos['trajectory_length'] = t + 1
 | 
			
		||||
            trajectory_return = self.reward_aggregation(rewards[:t + 1])
 | 
			
		||||
            return self.observation(obs), trajectory_return, done, infos
 | 
			
		||||
        else:
 | 
			
		||||
            obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity)
 | 
			
		||||
            obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity,
 | 
			
		||||
                                                                                 self.return_context_observation)
 | 
			
		||||
            return self.observation(obs), trajectory_return, done, infos
 | 
			
		||||
 | 
			
		||||
    def render(self, **kwargs):
 | 
			
		||||
 | 
			
		||||
@ -55,8 +55,8 @@ class MPWrapper(RawInterfaceWrapper):
 | 
			
		||||
            return False
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \
 | 
			
		||||
            -> Tuple[np.ndarray, float, bool, dict]:
 | 
			
		||||
    def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray,
 | 
			
		||||
                              return_contextual_obs: bool) -> Tuple[np.ndarray, float, bool, dict]:
 | 
			
		||||
        tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
 | 
			
		||||
        delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
 | 
			
		||||
        violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
 | 
			
		||||
@ -64,6 +64,8 @@ class MPWrapper(RawInterfaceWrapper):
 | 
			
		||||
        invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
 | 
			
		||||
                          violate_high_bound_error + violate_low_bound_error
 | 
			
		||||
        obs = np.concatenate([self.get_obs(), np.array([0])])
 | 
			
		||||
        if return_contextual_obs:
 | 
			
		||||
            obs = self.get_obs()
 | 
			
		||||
        return obs, -invalid_penalty, True, {
 | 
			
		||||
        "hit_ball": [False],
 | 
			
		||||
        "ball_returned_success": [False],
 | 
			
		||||
 | 
			
		||||
@ -157,17 +157,18 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    render = True
 | 
			
		||||
    # DMP
 | 
			
		||||
    example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
 | 
			
		||||
    # example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
 | 
			
		||||
 | 
			
		||||
    # ProMP
 | 
			
		||||
    example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
 | 
			
		||||
    example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
 | 
			
		||||
    # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
 | 
			
		||||
    # example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
 | 
			
		||||
    example_mp("TableTennis4DProMP-v0", seed=10, iterations=20, render=render)
 | 
			
		||||
 | 
			
		||||
    # ProDMP
 | 
			
		||||
    example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render)
 | 
			
		||||
    # example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render)
 | 
			
		||||
 | 
			
		||||
    # Altered basis functions
 | 
			
		||||
    obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
 | 
			
		||||
    # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
 | 
			
		||||
 | 
			
		||||
    # Custom MP
 | 
			
		||||
    example_fully_custom_mp(seed=10, iterations=1, render=render)
 | 
			
		||||
    # example_fully_custom_mp(seed=10, iterations=1, render=render)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user