Update BB wrapper to follow new spec for render_kwargs
This commit is contained in:
parent
c985f2c415
commit
5db73f90c4
@ -75,7 +75,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
self.observation_space = self._get_observation_space()
|
||||
|
||||
# rendering
|
||||
self.render_kwargs = {}
|
||||
self.do_render = False
|
||||
self.verbose = verbose
|
||||
|
||||
# condition value
|
||||
@ -164,7 +164,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
dtype=self.env.observation_space.dtype)
|
||||
|
||||
infos = dict()
|
||||
done = False
|
||||
terminated, truncated = False, False
|
||||
|
||||
if not traj_is_valid:
|
||||
obs, trajectory_return, terminated, truncated, infos = self.env.invalid_traj_callback(action, position, velocity,
|
||||
@ -190,8 +190,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
elems[t] = v
|
||||
infos[k] = elems
|
||||
|
||||
if self.render_kwargs:
|
||||
self.env.render(**self.render_kwargs)
|
||||
if self.do_render:
|
||||
self.env.render()
|
||||
|
||||
|
||||
if terminated or truncated or (self.replanning_schedule(self.env.get_wrapper_attr('current_pos'), self.env.get_wrapper_attr('current_vel'), obs, c_action, t + 1 + self.current_traj_steps) and self.plan_steps < self.max_planning_times):
|
||||
|
||||
@ -215,10 +216,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
trajectory_return = self.reward_aggregation(rewards[:t + 1])
|
||||
return self.observation(obs), trajectory_return, terminated, truncated, infos
|
||||
|
||||
def render(self, **kwargs):
|
||||
"""Only set render options here, such that they can be used during the rollout.
|
||||
This only needs to be called once"""
|
||||
self.render_kwargs = kwargs
|
||||
def render(self):
|
||||
self.do_render = True
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
|
||||
-> Tuple[ObsType, Dict[str, Any]]:
|
||||
|
Loading…
Reference in New Issue
Block a user