Force include reward in env infos (for vec env)
This commit is contained in:
parent
43cc749809
commit
f37c8caaa4
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, Optional, Type, Union, NamedTuple, Generator
|
from typing import Any, Dict, Optional, Type, Union, NamedTuple, Generator, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
@ -228,6 +228,10 @@ class GaussianRolloutCollectorAuxclass():
|
|||||||
actions, self.action_space.low, self.action_space.high)
|
actions, self.action_space.low, self.action_space.high)
|
||||||
|
|
||||||
new_obs, rewards, dones, infos = env.step(clipped_actions)
|
new_obs, rewards, dones, infos = env.step(clipped_actions)
|
||||||
|
if len(infos) and 'r' not in infos[0]:
|
||||||
|
for i in range(len(infos)):
|
||||||
|
if 'r' not in infos[i]:
|
||||||
|
infos[i]['r'] = rewards[i]
|
||||||
|
|
||||||
self.num_timesteps += env.num_envs
|
self.num_timesteps += env.num_envs
|
||||||
|
|
||||||
@ -274,3 +278,22 @@ class GaussianRolloutCollectorAuxclass():
|
|||||||
callback.on_rollout_end()
|
callback.on_rollout_end()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
    """
    Retrieve reward, episode length, episode success and update the buffer
    if using Monitor wrapper or a GoalEnv.

    For every transition info dict, the value stored under ``"episode"``
    (when present) is added to ``self.ep_info_buffer``. The value stored
    under ``"is_success"`` (when present) is added to
    ``self.ep_success_buffer``, but only for transitions flagged as done.

    :param infos: List of additional information about the transition.
    :param dones: Termination signals, indexed in the same order as ``infos``.
        When omitted, every transition is treated as not done, so no
        success value is recorded.
    """
    # NOTE: a leftover ``pdb.set_trace()`` breakpoint used to sit here; it
    # stopped every call and broke non-interactive runs, so it was removed.
    if dones is None:
        dones = np.zeros(len(infos), dtype=bool)
    for idx, info in enumerate(infos):
        maybe_ep_info = info.get("episode")
        maybe_is_success = info.get("is_success")
        if maybe_ep_info is not None:
            self.ep_info_buffer.extend([maybe_ep_info])
        # Success is only meaningful once the episode has ended.
        if maybe_is_success is not None and dones[idx]:
            self.ep_success_buffer.append(maybe_is_success)
|
||||||
|
Loading…
Reference in New Issue
Block a user