Merge pull request #55 from HongyiZhouCN/fix_bugs_for_replanning
Fix bugs for replanning
commit 61626135ff
@@ -22,7 +22,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
                  replanning_schedule: Optional[
                      Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
                  reward_aggregation: Callable[[np.ndarray], float] = np.sum,
-                 max_planning_times: int = None,
+                 max_planning_times: int = np.inf,
                  condition_on_desired: bool = False
                  ):
         """
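The default for max_planning_times changes from None to np.inf, so the wrapper can compare the plan counter against the budget with a plain "<" (see the condition change in the hunk at line 180 below) instead of special-casing None. A minimal sketch of the comparison semantics, using only numpy:

    import numpy as np

    # With np.inf as the "unlimited" default, any finite plan count
    # passes the budget check, so no None guard is needed.
    max_planning_times = np.inf
    assert 10_000 < max_planning_times   # replanning is never blocked

    # A finite budget reuses the exact same comparison to enforce the cap.
    max_planning_times = 4
    assert 3 < max_planning_times        # a 4th plan is still allowed
    assert not (4 < max_planning_times)  # the budget is then exhausted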
@@ -161,9 +161,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
 
         self.plan_steps += 1
         for t, (pos, vel) in enumerate(zip(position, velocity)):
-            current_pos = self.current_pos
-            current_vel = self.current_vel
-            step_action = self.tracking_controller.get_action(pos, vel, current_pos, current_vel)
+            step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
             c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
             obs, c_reward, done, info = self.env.step(c_action)
             rewards[t] = c_reward
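The locals current_pos and current_vel were snapshotted at the top of the loop iteration, before self.env.step(c_action) ran, so the replanning check further down received the pre-step state; passing self.current_pos and self.current_vel directly reads the wrapper's post-step values. This appears to be one of the bugs the PR title refers to. A stripped-down sketch of the stale-capture pattern (WrapperLike and its methods are illustrative, not fancy_gym API):

    # Illustrative only: mimics snapshotting state before stepping the env.
    class WrapperLike:
        def __init__(self):
            self.current_pos = 0.0

        def env_step(self):
            self.current_pos += 1.0  # stepping updates the tracked state

        def check_buggy(self):
            current_pos = self.current_pos  # snapshot taken BEFORE the step
            self.env_step()
            return current_pos              # stale pre-step value

        def check_fixed(self):
            self.env_step()
            return self.current_pos         # fresh post-step value

    assert WrapperLike().check_buggy() == 0.0  # the schedule saw the old state
    assert WrapperLike().check_fixed() == 1.0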
@@ -180,11 +178,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
             if self.render_kwargs:
                 self.env.render(**self.render_kwargs)
 
-            if done or self.replanning_schedule(current_pos, current_vel, obs, c_action,
-                                                t + 1 + self.current_traj_steps):
-                if self.max_planning_times is not None and self.plan_steps >= self.max_planning_times:
-                    continue
+            if done or (self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
+                                                 t + 1 + self.current_traj_steps)
+                        and self.plan_steps < self.max_planning_times):
 
                 self.condition_pos = pos if self.condition_on_desired else None
                 self.condition_vel = vel if self.condition_on_desired else None
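The planning-budget check moves into the replanning condition itself: instead of entering the branch and bailing out with continue once the budget is spent, the schedule and the budget are now conjoined, so the branch is taken only while plan_steps < max_planning_times (or when the episode is done). A minimal sketch of the combined predicate, reusing the every-25-steps schedule and the budget of 4 from the registration change below:

    replanning_schedule = lambda pos, vel, obs, action, t: t % 25 == 0
    max_planning_times = 4

    def should_replan(done, t, plan_steps):
        # pos/vel/obs/action omitted: this particular schedule only uses t.
        return done or (replanning_schedule(None, None, None, None, t)
                        and plan_steps < max_planning_times)

    assert should_replan(done=False, t=25, plan_steps=1)      # schedule fires
    assert not should_replan(done=False, t=26, plan_steps=1)  # off-schedule
    assert not should_replan(done=False, t=50, plan_steps=4)  # budget exhausted
    assert should_replan(done=True, t=7, plan_steps=4)        # done still enters the branch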
@@ -214,4 +210,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         self.current_traj_steps = 0
         self.plan_steps = 0
         self.traj_gen.reset()
+        self.condition_vel = None
+        self.condition_pos = None
         return super(BlackBoxWrapper, self).reset()
@@ -498,7 +498,7 @@ for _v in _versions:
     kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 4
     kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3
     kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
-    kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 2
+    kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 4
     kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0
     kwargs_dict_box_pushing_prodmp['black_box_kwargs']['condition_on_desired'] = True
     register(
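With replanning every 25 steps, the raised budget of 4 spans a whole rollout: assuming a 100-step box-pushing episode (the horizon is not part of this diff), the initial plan plus the replans at t = 25, 50, 75 make exactly four planning segments:

    episode_len = 100  # assumed horizon, not shown in this diff
    replan_points = [t for t in range(1, episode_len) if t % 25 == 0]
    print(replan_points)           # [25, 50, 75]
    print(1 + len(replan_points))  # 4 plans in total, matching max_planning_times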