diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py index d8dcbaa..66c5f3e 100644 --- a/fancy_gym/black_box/black_box_wrapper.py +++ b/fancy_gym/black_box/black_box_wrapper.py @@ -22,7 +22,7 @@ class BlackBoxWrapper(gym.ObservationWrapper): replanning_schedule: Optional[ Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None, reward_aggregation: Callable[[np.ndarray], float] = np.sum, - max_planning_times: int = None, + max_planning_times: int = np.inf, condition_on_desired: bool = False ): """ @@ -178,11 +178,9 @@ class BlackBoxWrapper(gym.ObservationWrapper): if self.render_kwargs: self.env.render(**self.render_kwargs) - if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action, - t + 1 + self.current_traj_steps): - - if not done and self.max_planning_times is not None and self.plan_steps >= self.max_planning_times: - continue + if done or (self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action, + t + 1 + self.current_traj_steps) + and self.plan_steps < self.max_planning_times): self.condition_pos = pos if self.condition_on_desired else None self.condition_vel = vel if self.condition_on_desired else None