Force include reward in env infos (for vec env)

Dominik Moritz Roth 2023-01-26 12:00:18 +01:00
parent 43cc749809
commit f37c8caaa4


@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Type, Union, NamedTuple, Generator
+from typing import Any, Dict, Optional, Type, Union, NamedTuple, Generator, List
 import numpy as np
 import torch as th
@@ -228,6 +228,10 @@ class GaussianRolloutCollectorAuxclass():
                 actions, self.action_space.low, self.action_space.high)
             new_obs, rewards, dones, infos = env.step(clipped_actions)
+            if len(infos) and 'r' not in infos[0]:
+                for i in range(len(infos)):
+                    if 'r' not in infos[i]:
+                        infos[i]['r'] = rewards[i]
             self.num_timesteps += env.num_envs
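
The hunk above writes the raw step reward into each worker's info dict under the key 'r' whenever no wrapper (such as Monitor) has already provided one, so downstream consumers of infos can rely on the reward being present for every vectorized environment. A minimal sketch of the resulting behaviour; the env id, DummyVecEnv setup and sampled actions below are assumptions for illustration, not part of the commit:

# Illustrative only: env id and vec-env setup are assumptions; the loop mirrors the added hunk.
import gym
import numpy as np
from stable_baselines3.common.vec_env import DummyVecEnv

env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
env.reset()
actions = np.array([env.action_space.sample() for _ in range(env.num_envs)])
new_obs, rewards, dones, infos = env.step(actions)

if len(infos) and 'r' not in infos[0]:
    for i in range(len(infos)):
        if 'r' not in infos[i]:
            infos[i]['r'] = rewards[i]

assert all('r' in info for info in infos)  # reward is now always available per env
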
@@ -274,3 +278,22 @@ class GaussianRolloutCollectorAuxclass():
         callback.on_rollout_end()
         return True
+
+    def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
+        """
+        Retrieve reward, episode length, episode success and update the buffer
+        if using Monitor wrapper or a GoalEnv.
+        :param infos: List of additional information about the transition.
+        :param dones: Termination signals
+        """
+        if dones is None:
+            dones = np.array([False] * len(infos))
+        for idx, info in enumerate(infos):
+            maybe_ep_info = info.get("episode")
+            maybe_is_success = info.get("is_success")
+            if maybe_ep_info is not None:
+                self.ep_info_buffer.extend([maybe_ep_info])
+            if maybe_is_success is not None and dones[idx]:
+                self.ep_success_buffer.append(maybe_is_success)
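
The added _update_info_buffer closely follows the helper of the same name in stable-baselines3's BaseAlgorithm: it collects Monitor-style "episode" records into ep_info_buffer and, on episode end, "is_success" flags into ep_success_buffer. A minimal usage sketch; the stub class and deque-based buffers are assumptions for illustration, not part of the commit:

# Standalone sketch: _CollectorStub and its deque buffers are hypothetical
# stand-ins for the real collector class; the method body matches the commit.
from collections import deque
from typing import Any, Dict, List, Optional

import numpy as np


class _CollectorStub:
    def __init__(self) -> None:
        self.ep_info_buffer = deque(maxlen=100)
        self.ep_success_buffer = deque(maxlen=100)

    def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
        if dones is None:
            dones = np.array([False] * len(infos))
        for idx, info in enumerate(infos):
            maybe_ep_info = info.get("episode")
            maybe_is_success = info.get("is_success")
            if maybe_ep_info is not None:
                self.ep_info_buffer.extend([maybe_ep_info])
            if maybe_is_success is not None and dones[idx]:
                self.ep_success_buffer.append(maybe_is_success)


collector = _CollectorStub()
infos = [
    {"r": 1.0},  # mid-episode step: reward forced in by the earlier hunk
    {"r": 0.0, "episode": {"r": 21.0, "l": 42}, "is_success": True},  # episode end
]
collector._update_info_buffer(infos, dones=np.array([False, True]))
assert len(collector.ep_info_buffer) == 1
assert list(collector.ep_success_buffer) == [True]
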