Merge pull request #55 from HongyiZhouCN/fix_bugs_for_replanning

Fix bugs for replanning
This commit is contained in:
ottofabian 2022-11-28 13:53:08 +01:00 committed by GitHub
commit 61626135ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 10 deletions

View File

@@ -22,7 +22,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
replanning_schedule: Optional[ replanning_schedule: Optional[
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None, Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
reward_aggregation: Callable[[np.ndarray], float] = np.sum, reward_aggregation: Callable[[np.ndarray], float] = np.sum,
max_planning_times: int = None, max_planning_times: int = np.inf,
condition_on_desired: bool = False condition_on_desired: bool = False
): ):
""" """
@@ -161,9 +161,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
self.plan_steps += 1 self.plan_steps += 1
for t, (pos, vel) in enumerate(zip(position, velocity)): for t, (pos, vel) in enumerate(zip(position, velocity)):
current_pos = self.current_pos step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
current_vel = self.current_vel
step_action = self.tracking_controller.get_action(pos, vel, current_pos, current_vel)
c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high) c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
obs, c_reward, done, info = self.env.step(c_action) obs, c_reward, done, info = self.env.step(c_action)
rewards[t] = c_reward rewards[t] = c_reward
@@ -180,11 +178,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
if self.render_kwargs: if self.render_kwargs:
self.env.render(**self.render_kwargs) self.env.render(**self.render_kwargs)
if done or self.replanning_schedule(current_pos, current_vel, obs, c_action, if done or (self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
t + 1 + self.current_traj_steps): t + 1 + self.current_traj_steps)
and self.plan_steps < self.max_planning_times):
if self.max_planning_times is not None and self.plan_steps >= self.max_planning_times:
continue
self.condition_pos = pos if self.condition_on_desired else None self.condition_pos = pos if self.condition_on_desired else None
self.condition_vel = vel if self.condition_on_desired else None self.condition_vel = vel if self.condition_on_desired else None
@@ -214,4 +210,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
self.current_traj_steps = 0 self.current_traj_steps = 0
self.plan_steps = 0 self.plan_steps = 0
self.traj_gen.reset() self.traj_gen.reset()
self.condition_vel = None
self.condition_pos = None
return super(BlackBoxWrapper, self).reset() return super(BlackBoxWrapper, self).reset()

View File

@@ -498,7 +498,7 @@ for _v in _versions:
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 4 kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 4
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3
kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3 kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 2 kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 4
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0 kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['condition_on_desired'] = True kwargs_dict_box_pushing_prodmp['black_box_kwargs']['condition_on_desired'] = True
register( register(