Update BB wrapper to follow new spec for render_kwargs

This commit is contained in:
Dominik Moritz Roth 2023-10-23 12:25:53 +02:00
parent c985f2c415
commit 5db73f90c4

View File

@ -75,7 +75,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
self.observation_space = self._get_observation_space()
# rendering
self.render_kwargs = {}
self.do_render = False
self.verbose = verbose
# condition value
@ -164,7 +164,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
dtype=self.env.observation_space.dtype)
infos = dict()
done = False
terminated, truncated = False, False
if not traj_is_valid:
obs, trajectory_return, terminated, truncated, infos = self.env.invalid_traj_callback(action, position, velocity,
@ -190,8 +190,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
elems[t] = v
infos[k] = elems
if self.render_kwargs:
self.env.render(**self.render_kwargs)
if self.do_render:
self.env.render()
if terminated or truncated or (self.replanning_schedule(self.env.get_wrapper_attr('current_pos'), self.env.get_wrapper_attr('current_vel'), obs, c_action, t + 1 + self.current_traj_steps) and self.plan_steps < self.max_planning_times):
@ -215,10 +216,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
trajectory_return = self.reward_aggregation(rewards[:t + 1])
return self.observation(obs), trajectory_return, terminated, truncated, infos
def render(self, **kwargs):
"""Only set render options here, such that they can be used during the rollout.
This only needs to be called once"""
self.render_kwargs = kwargs
def render(self):
self.do_render = True
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
-> Tuple[ObsType, Dict[str, Any]]: