diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py
index 336ea44..4ff685a 100644
--- a/fancy_gym/black_box/black_box_wrapper.py
+++ b/fancy_gym/black_box/black_box_wrapper.py
@@ -22,7 +22,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
                  replanning_schedule: Optional[
                      Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
                  reward_aggregation: Callable[[np.ndarray], float] = np.sum,
-                 max_planning_times: int = None,
+                 max_planning_times: int = np.inf,
                  condition_on_desired: bool = False
                  ):
         """
@@ -86,7 +86,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
 
     def get_trajectory(self, action: np.ndarray) -> Tuple:
         duration = self.duration
-        # duration = self.duration - self.current_traj_steps * self.dt
         if self.learn_sub_trajectories:
             duration = None
             # reset with every new call as we need to set all arguments, such as tau, delay, again.
@@ -187,11 +186,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
 
             if self.render_kwargs:
                 self.env.render(**self.render_kwargs)
 
-            if done or self.replanning_schedule(current_pos, current_vel, obs, c_action,
-                                                t + 1 + self.current_traj_steps):
-
-                if not done and self.max_planning_times is not None and self.plan_steps >= self.max_planning_times:
-                    continue
+            if done or (self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
+                                                 t + 1 + self.current_traj_steps)
+                        and self.plan_steps < self.max_planning_times):
                 self.condition_pos = pos if self.condition_on_desired else None
                 self.condition_vel = vel if self.condition_on_desired else None
@@ -221,4 +218,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         self.current_traj_steps = 0
         self.plan_steps = 0
         self.traj_gen.reset()
+        self.condition_vel = None
+        self.condition_pos = None
         return super(BlackBoxWrapper, self).reset()
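
A minimal sketch (not part of the patch) of why the rewritten replanning gate is equivalent to the old two-step check. The names should_replan and schedule_triggered are hypothetical stand-ins for the wrapper's attributes; with the new default max_planning_times = np.inf, the budget comparison always passes, so the old None sentinel and the inner `continue` bail-out are no longer needed.

import numpy as np

def should_replan(done: bool, schedule_triggered: bool,
                  plan_steps: int, max_planning_times: float = np.inf) -> bool:
    # Mirrors the patched condition: replan when the episode is done, or when
    # the replanning schedule fires while the plan budget is not yet exhausted.
    # (Hypothetical simplification of BlackBoxWrapper.step's branch.)
    return done or (schedule_triggered and plan_steps < max_planning_times)

# The old code entered the branch and then `continue`d out of it when the
# budget was exhausted; the patch folds that bail-out into the condition:
assert should_replan(done=False, schedule_triggered=True, plan_steps=3, max_planning_times=3) is False
assert should_replan(done=False, schedule_triggered=True, plan_steps=2, max_planning_times=3) is True
assert should_replan(done=True, schedule_triggered=False, plan_steps=99) is True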