Update BlackBoxWrapper (BB wrapper) to follow the new rendering spec: replace render_kwargs with a do_render flag
This commit is contained in:
parent
c985f2c415
commit
5db73f90c4
@ -75,7 +75,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
self.observation_space = self._get_observation_space()
|
self.observation_space = self._get_observation_space()
|
||||||
|
|
||||||
# rendering
|
# rendering
|
||||||
self.render_kwargs = {}
|
self.do_render = False
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
|
||||||
# condition value
|
# condition value
|
||||||
@ -164,7 +164,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
dtype=self.env.observation_space.dtype)
|
dtype=self.env.observation_space.dtype)
|
||||||
|
|
||||||
infos = dict()
|
infos = dict()
|
||||||
done = False
|
terminated, truncated = False, False
|
||||||
|
|
||||||
if not traj_is_valid:
|
if not traj_is_valid:
|
||||||
obs, trajectory_return, terminated, truncated, infos = self.env.invalid_traj_callback(action, position, velocity,
|
obs, trajectory_return, terminated, truncated, infos = self.env.invalid_traj_callback(action, position, velocity,
|
||||||
@ -190,8 +190,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
elems[t] = v
|
elems[t] = v
|
||||||
infos[k] = elems
|
infos[k] = elems
|
||||||
|
|
||||||
if self.render_kwargs:
|
if self.do_render:
|
||||||
self.env.render(**self.render_kwargs)
|
self.env.render()
|
||||||
|
|
||||||
|
|
||||||
if terminated or truncated or (self.replanning_schedule(self.env.get_wrapper_attr('current_pos'), self.env.get_wrapper_attr('current_vel'), obs, c_action, t + 1 + self.current_traj_steps) and self.plan_steps < self.max_planning_times):
|
if terminated or truncated or (self.replanning_schedule(self.env.get_wrapper_attr('current_pos'), self.env.get_wrapper_attr('current_vel'), obs, c_action, t + 1 + self.current_traj_steps) and self.plan_steps < self.max_planning_times):
|
||||||
|
|
||||||
@ -215,10 +216,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
trajectory_return = self.reward_aggregation(rewards[:t + 1])
|
trajectory_return = self.reward_aggregation(rewards[:t + 1])
|
||||||
return self.observation(obs), trajectory_return, terminated, truncated, infos
|
return self.observation(obs), trajectory_return, terminated, truncated, infos
|
||||||
|
|
||||||
def render(self, **kwargs):
|
def render(self):
|
||||||
"""Only set render options here, such that they can be used during the rollout.
|
self.do_render = True
|
||||||
This only needs to be called once"""
|
|
||||||
self.render_kwargs = kwargs
|
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
|
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
|
||||||
-> Tuple[ObsType, Dict[str, Any]]:
|
-> Tuple[ObsType, Dict[str, Any]]:
|
||||||
|
Loading…
Reference in New Issue
Block a user