diff --git a/metastable_baselines/misc/rollout_buffer.py b/metastable_baselines/misc/rollout_buffer.py
index b4f6c6d..e7408d7 100644
--- a/metastable_baselines/misc/rollout_buffer.py
+++ b/metastable_baselines/misc/rollout_buffer.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Type, Union, NamedTuple, Generator
+from typing import Any, Dict, Optional, Type, Union, NamedTuple, Generator, List
 import numpy as np
 import torch as th
 
@@ -228,6 +228,10 @@ class GaussianRolloutCollectorAuxclass():
                 actions, self.action_space.low, self.action_space.high)
 
             new_obs, rewards, dones, infos = env.step(clipped_actions)
+            if len(infos) and 'r' not in infos[0]:
+                for i in range(len(infos)):
+                    if 'r' not in infos[i]:
+                        infos[i]['r'] = rewards[i]
 
             self.num_timesteps += env.num_envs
 
@@ -274,3 +278,20 @@
         callback.on_rollout_end()
 
         return True
+
+    def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
+        """
+        Retrieve reward, episode length, episode success and update the buffer
+        if using Monitor wrapper or a GoalEnv.
+        :param infos: List of additional information about the transition.
+        :param dones: Termination signals
+        """
+        if dones is None:
+            dones = np.array([False] * len(infos))
+        for idx, info in enumerate(infos):
+            maybe_ep_info = info.get("episode")
+            maybe_is_success = info.get("is_success")
+            if maybe_ep_info is not None:
+                self.ep_info_buffer.extend([maybe_ep_info])
+            if maybe_is_success is not None and dones[idx]:
+                self.ep_success_buffer.append(maybe_is_success)