Update BB wrapper to follow new spec for render_kwargs
This commit is contained in:
		
							parent
							
								
									c985f2c415
								
							
						
					
					
						commit
						5db73f90c4
					
				| @ -75,7 +75,7 @@ class BlackBoxWrapper(gym.ObservationWrapper): | ||||
|         self.observation_space = self._get_observation_space() | ||||
| 
 | ||||
|         # rendering | ||||
|         self.render_kwargs = {} | ||||
|         self.do_render = False | ||||
|         self.verbose = verbose | ||||
| 
 | ||||
|         # condition value | ||||
| @ -164,7 +164,7 @@ class BlackBoxWrapper(gym.ObservationWrapper): | ||||
|                                     dtype=self.env.observation_space.dtype) | ||||
| 
 | ||||
|         infos = dict() | ||||
|         done = False | ||||
|         terminated, truncated = False, False | ||||
| 
 | ||||
|         if not traj_is_valid: | ||||
|             obs, trajectory_return, terminated, truncated, infos = self.env.invalid_traj_callback(action, position, velocity, | ||||
| @ -190,8 +190,9 @@ class BlackBoxWrapper(gym.ObservationWrapper): | ||||
|                 elems[t] = v | ||||
|                 infos[k] = elems | ||||
| 
 | ||||
|             if self.render_kwargs: | ||||
|                 self.env.render(**self.render_kwargs) | ||||
|             if self.do_render: | ||||
|                 self.env.render() | ||||
| 
 | ||||
| 
 | ||||
|             if terminated or truncated or (self.replanning_schedule(self.env.get_wrapper_attr('current_pos'), self.env.get_wrapper_attr('current_vel'), obs, c_action, t + 1 + self.current_traj_steps) and self.plan_steps < self.max_planning_times): | ||||
| 
 | ||||
| @ -215,10 +216,8 @@ class BlackBoxWrapper(gym.ObservationWrapper): | ||||
|         trajectory_return = self.reward_aggregation(rewards[:t + 1]) | ||||
|         return self.observation(obs), trajectory_return, terminated, truncated, infos | ||||
| 
 | ||||
|     def render(self, **kwargs): | ||||
|         """Only set render options here, such that they can be used during the rollout. | ||||
|         This only needs to be called once""" | ||||
|         self.render_kwargs = kwargs | ||||
|     def render(self): | ||||
|         self.do_render = True | ||||
| 
 | ||||
|     def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \ | ||||
|             -> Tuple[ObsType, Dict[str, Any]]: | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user