From f37c8caaa424d86e058b95f3ac8e99154fd9de51 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Thu, 26 Jan 2023 12:00:18 +0100
Subject: [PATCH] Force include reward in env infos (for vec env)

---
 metastable_baselines/misc/rollout_buffer.py | 24 +++++++++++++++++++++++-
 1 file changed, 23 insertions(+), 1 deletion(-)

diff --git a/metastable_baselines/misc/rollout_buffer.py b/metastable_baselines/misc/rollout_buffer.py
index b4f6c6d..e7408d7 100644
--- a/metastable_baselines/misc/rollout_buffer.py
+++ b/metastable_baselines/misc/rollout_buffer.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Type, Union, NamedTuple, Generator
+from typing import Any, Dict, Optional, Type, Union, NamedTuple, Generator, List
 import numpy as np
 import torch as th
 
@@ -228,6 +228,11 @@ class GaussianRolloutCollectorAuxclass():
                 actions, self.action_space.low, self.action_space.high)
 
             new_obs, rewards, dones, infos = env.step(clipped_actions)
+            # VecEnv infos may omit the per-step reward; force it in under 'r'
+            if len(infos) and 'r' not in infos[0]:
+                for i in range(len(infos)):
+                    if 'r' not in infos[i]:
+                        infos[i]['r'] = rewards[i]
 
             self.num_timesteps += env.num_envs
 
@@ -274,3 +279,20 @@ class GaussianRolloutCollectorAuxclass():
         callback.on_rollout_end()
 
         return True
+
+    def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
+        """
+        Retrieve reward, episode length, episode success and update the buffer
+        if using Monitor wrapper or a GoalEnv.
+        :param infos: List of additional information about the transition.
+        :param dones: Termination signals
+        """
+        if dones is None:
+            dones = np.array([False] * len(infos))
+        for idx, info in enumerate(infos):
+            maybe_ep_info = info.get("episode")
+            maybe_is_success = info.get("is_success")
+            if maybe_ep_info is not None:
+                self.ep_info_buffer.extend([maybe_ep_info])
+            if maybe_is_success is not None and dones[idx]:
+                self.ep_success_buffer.append(maybe_is_success)
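
Below is a minimal standalone sketch of the behavior the second hunk adds; the helper name force_reward_into_infos and the toy two-env values are illustrative assumptions, not part of the patch. After a vectorized env.step, every per-env info dict ends up carrying its scalar reward under the 'r' key, without overwriting a value the env already reported.

    import numpy as np

    def force_reward_into_infos(infos, rewards):
        # Mirror of the patched block: cheap gate on the first info dict,
        # then fill 'r' per env, never overwriting an existing entry.
        if len(infos) and 'r' not in infos[0]:
            for i in range(len(infos)):
                if 'r' not in infos[i]:
                    infos[i]['r'] = rewards[i]
        return infos

    # Hypothetical step result from a 2-env VecEnv: env 0 omits its reward,
    # env 1 already reports one.
    rewards = np.array([1.0, 0.5])
    infos = [{}, {'r': 0.5}]
    print(force_reward_into_infos(infos, rewards))
    # -> [{'r': 1.0}, {'r': 0.5}]

Note that the gate on infos[0] means the per-env fill only runs when the first env lacks 'r'; the patch appears to assume the envs in a VecEnv are homogeneous in what they report.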