diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py index 6da24c7..c963f0c 100644 --- a/fancy_gym/black_box/black_box_wrapper.py +++ b/fancy_gym/black_box/black_box_wrapper.py @@ -75,7 +75,7 @@ class BlackBoxWrapper(gym.ObservationWrapper): self.observation_space = self._get_observation_space() # rendering - self.render_kwargs = {} + self.do_render = False self.verbose = verbose # condition value @@ -164,7 +164,7 @@ class BlackBoxWrapper(gym.ObservationWrapper): dtype=self.env.observation_space.dtype) infos = dict() - done = False + terminated, truncated = False, False if not traj_is_valid: obs, trajectory_return, terminated, truncated, infos = self.env.invalid_traj_callback(action, position, velocity, @@ -190,8 +190,9 @@ class BlackBoxWrapper(gym.ObservationWrapper): elems[t] = v infos[k] = elems - if self.render_kwargs: - self.env.render(**self.render_kwargs) + if self.do_render: + self.env.render() + if terminated or truncated or (self.replanning_schedule(self.env.get_wrapper_attr('current_pos'), self.env.get_wrapper_attr('current_vel'), obs, c_action, t + 1 + self.current_traj_steps) and self.plan_steps < self.max_planning_times): @@ -215,10 +216,8 @@ class BlackBoxWrapper(gym.ObservationWrapper): trajectory_return = self.reward_aggregation(rewards[:t + 1]) return self.observation(obs), trajectory_return, terminated, truncated, infos - def render(self, **kwargs): - """Only set render options here, such that they can be used during the rollout. - This only needs to be called once""" - self.render_kwargs = kwargs + def render(self): + self.do_render = True def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \ -> Tuple[ObsType, Dict[str, Any]]: