Force include reward in env infos (for vec env)
This commit is contained in:
parent
43cc749809
commit
f37c8caaa4
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional, Type, Union, NamedTuple, Generator
|
||||
from typing import Any, Dict, Optional, Type, Union, NamedTuple, Generator, List
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
@ -228,6 +228,10 @@ class GaussianRolloutCollectorAuxclass():
|
||||
actions, self.action_space.low, self.action_space.high)
|
||||
|
||||
new_obs, rewards, dones, infos = env.step(clipped_actions)
|
||||
if len(infos) and 'r' not in infos[0]:
|
||||
for i in range(len(infos)):
|
||||
if 'r' not in infos[i]:
|
||||
infos[i]['r'] = rewards[i]
|
||||
|
||||
self.num_timesteps += env.num_envs
|
||||
|
||||
@ -274,3 +278,22 @@ class GaussianRolloutCollectorAuxclass():
|
||||
callback.on_rollout_end()
|
||||
|
||||
return True
|
||||
|
||||
def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
    """
    Copy per-episode statistics from the step infos into the logging buffers.

    For every info dict, the value stored under ``"episode"`` (if present)
    is added to ``self.ep_info_buffer``. The value stored under
    ``"is_success"`` (if present) is added to ``self.ep_success_buffer``,
    but only when the matching entry of ``dones`` is truthy.

    NOTE(review): the ``"episode"`` key is presumably written by a Monitor
    wrapper and ``"is_success"`` by a GoalEnv -- confirm against the
    environments actually used.

    :param infos: List of additional information about the transition,
        one dict per environment.
    :param dones: Termination signals, indexable by environment position.
        When ``None``, every environment is treated as not done, so no
        success flag is recorded.
    """
    # A leftover ``pdb.set_trace()`` used to live here; it stopped every call
    # at an interactive prompt and made non-interactive runs hang or crash.
    if dones is None:
        dones = np.zeros(len(infos), dtype=bool)

    for idx, info in enumerate(infos):
        episode_stats = info.get("episode")
        success_flag = info.get("is_success")

        if episode_stats is not None:
            self.ep_info_buffer.extend([episode_stats])

        # Success is only meaningful once the episode has actually ended.
        if success_flag is not None and dones[idx]:
            self.ep_success_buffer.append(success_flag)
|
||||
|
Loading…
Reference in New Issue
Block a user