diff --git a/metastable_baselines/misc/rollout_buffer.py b/metastable_baselines/misc/rollout_buffer.py index fd5b441..846d60a 100644 --- a/metastable_baselines/misc/rollout_buffer.py +++ b/metastable_baselines/misc/rollout_buffer.py @@ -228,14 +228,6 @@ class GaussianRolloutCollectorAuxclass(): actions, self.action_space.low, self.action_space.high) new_obs, rewards, dones, infos = env.step(clipped_actions) - if 'episode_end' in infos[0]: - for i in range(len(infos)): - if infos[i]['episode_end'] and 'episode' not in infos: - infos[i]['episode'] = {'r': rewards[i]} - if len(infos) and 'r' not in infos[0]: - for i in range(len(infos)): - if 'r' not in infos[i]: - infos[i]['r'] = rewards[i] self.num_timesteps += env.num_envs @@ -282,20 +274,3 @@ class GaussianRolloutCollectorAuxclass(): callback.on_rollout_end() return True - - def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None: - """ - Retrieve reward, episode length, episode success and update the buffer - if using Monitor wrapper or a GoalEnv. - :param infos: List of additional information about the transition. - :param dones: Termination signals - """ - if dones is None: - dones = np.array([False] * len(infos)) - for idx, info in enumerate(infos): - maybe_ep_info = info.get("episode") - maybe_is_success = info.get("is_success") - if maybe_ep_info is not None: - self.ep_info_buffer.extend([maybe_ep_info]) - if maybe_is_success is not None and dones[idx]: - self.ep_success_buffer.append(maybe_is_success)