calc episodic infos for some fancy envs
This commit is contained in:
parent
f37c8caaa4
commit
6f1837bda5
@ -6,7 +6,7 @@ from gym import spaces
|
||||
|
||||
from stable_baselines3.common.buffers import RolloutBuffer
|
||||
from stable_baselines3.common.vec_env import VecNormalize
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
from stable_baselines3.common.vec_env import VecEnv, DummyVecEnv
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.utils import obs_as_tensor
|
||||
|
||||
@ -228,6 +228,10 @@ class GaussianRolloutCollectorAuxclass():
|
||||
actions, self.action_space.low, self.action_space.high)
|
||||
|
||||
new_obs, rewards, dones, infos = env.step(clipped_actions)
|
||||
if 'episode_end' in infos[0]:
|
||||
for i in range(len(infos)):
|
||||
if infos[i]['episode_end'] and 'episode' not in infos:
|
||||
infos[i]['episode'] = {'r': rewards[i]}
|
||||
if len(infos) and 'r' not in infos[0]:
|
||||
for i in range(len(infos)):
|
||||
if 'r' not in infos[i]:
|
||||
@ -286,8 +290,6 @@ class GaussianRolloutCollectorAuxclass():
|
||||
:param infos: List of additional information about the transition.
|
||||
:param dones: Termination signals
|
||||
"""
|
||||
import pdb
|
||||
pdb.set_trace()
|
||||
if dones is None:
|
||||
dones = np.array([False] * len(infos))
|
||||
for idx, info in enumerate(infos):
|
||||
|
@ -399,6 +399,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
self.logger.record("train/clip_fraction", np.mean(clip_fractions))
|
||||
self.logger.record("train/loss", loss.item())
|
||||
self.logger.record("train/explained_variance", explained_var)
|
||||
self.logger.record("train/ssf", self.sde_sample_freq)
|
||||
if hasattr(self.policy, "log_std"):
|
||||
self.logger.record(
|
||||
"train/std", th.exp(self.policy.log_std).mean().item())
|
||||
|
@ -374,6 +374,7 @@ class SAC(OffPolicyAlgorithm):
|
||||
self._n_updates, exclude="tensorboard")
|
||||
self.logger.record("train/action_loss", np.mean(action_losses))
|
||||
self.logger.record("train/ent_coef", np.mean(ent_coefs))
|
||||
self.logger.record("train/ssf", self.sde_sample_freq)
|
||||
self.logger.record("train/actor_loss", np.mean(actor_losses))
|
||||
self.logger.record("train/critic_loss", np.mean(critic_losses))
|
||||
if len(ent_coef_losses) > 0:
|
||||
|
Loading…
Reference in New Issue
Block a user