Update BB wrapper to follow new spec for render_kwargs

This commit is contained in:
Dominik Moritz Roth 2023-10-23 12:25:53 +02:00
parent c985f2c415
commit 5db73f90c4

View File

@ -75,7 +75,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
self.observation_space = self._get_observation_space() self.observation_space = self._get_observation_space()
# rendering # rendering
self.render_kwargs = {} self.do_render = False
self.verbose = verbose self.verbose = verbose
# condition value # condition value
@ -164,7 +164,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
dtype=self.env.observation_space.dtype) dtype=self.env.observation_space.dtype)
infos = dict() infos = dict()
done = False terminated, truncated = False, False
if not traj_is_valid: if not traj_is_valid:
obs, trajectory_return, terminated, truncated, infos = self.env.invalid_traj_callback(action, position, velocity, obs, trajectory_return, terminated, truncated, infos = self.env.invalid_traj_callback(action, position, velocity,
@ -190,8 +190,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
elems[t] = v elems[t] = v
infos[k] = elems infos[k] = elems
if self.render_kwargs: if self.do_render:
self.env.render(**self.render_kwargs) self.env.render()
if terminated or truncated or (self.replanning_schedule(self.env.get_wrapper_attr('current_pos'), self.env.get_wrapper_attr('current_vel'), obs, c_action, t + 1 + self.current_traj_steps) and self.plan_steps < self.max_planning_times): if terminated or truncated or (self.replanning_schedule(self.env.get_wrapper_attr('current_pos'), self.env.get_wrapper_attr('current_vel'), obs, c_action, t + 1 + self.current_traj_steps) and self.plan_steps < self.max_planning_times):
@ -215,10 +216,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
trajectory_return = self.reward_aggregation(rewards[:t + 1]) trajectory_return = self.reward_aggregation(rewards[:t + 1])
return self.observation(obs), trajectory_return, terminated, truncated, infos return self.observation(obs), trajectory_return, terminated, truncated, infos
def render(self, **kwargs): def render(self):
"""Only set render options here, such that they can be used during the rollout. self.do_render = True
This only needs to be called once"""
self.render_kwargs = kwargs
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \ def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
-> Tuple[ObsType, Dict[str, Any]]: -> Tuple[ObsType, Dict[str, Any]]: