diff --git a/metastable_baselines/misc/rollout_buffer.py b/metastable_baselines/misc/rollout_buffer.py
index e7408d7..fd5b441 100644
--- a/metastable_baselines/misc/rollout_buffer.py
+++ b/metastable_baselines/misc/rollout_buffer.py
@@ -6,7 +6,7 @@
 from gym import spaces
 from stable_baselines3.common.buffers import RolloutBuffer
 from stable_baselines3.common.vec_env import VecNormalize
-from stable_baselines3.common.vec_env import VecEnv
+from stable_baselines3.common.vec_env import VecEnv, DummyVecEnv
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.utils import obs_as_tensor
 
@@ -228,6 +228,12 @@ class GaussianRolloutCollectorAuxclass():
                     actions, self.action_space.low, self.action_space.high)
 
             new_obs, rewards, dones, infos = env.step(clipped_actions)
+            # Add a minimal 'episode' entry for envs that flag 'episode_end';
+            # check the per-env dict so an existing entry is never overwritten.
+            if len(infos) and 'episode_end' in infos[0]:
+                for i in range(len(infos)):
+                    if infos[i].get('episode_end') and 'episode' not in infos[i]:
+                        infos[i]['episode'] = {'r': rewards[i]}
             if len(infos) and 'r' not in infos[0]:
                 for i in range(len(infos)):
                     if 'r' not in infos[i]:
@@ -286,8 +292,6 @@ class GaussianRolloutCollectorAuxclass():
         :param infos: List of additional information about the transition.
         :param dones: Termination signals
         """
-        import pdb
-        pdb.set_trace()
         if dones is None:
             dones = np.array([False] * len(infos))
         for idx, info in enumerate(infos):
diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py
index 3caec81..bfc2844 100644
--- a/metastable_baselines/ppo/ppo.py
+++ b/metastable_baselines/ppo/ppo.py
@@ -399,6 +399,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         self.logger.record("train/clip_fraction", np.mean(clip_fractions))
         self.logger.record("train/loss", loss.item())
         self.logger.record("train/explained_variance", explained_var)
+        self.logger.record("train/ssf", self.sde_sample_freq)
         if hasattr(self.policy, "log_std"):
             self.logger.record(
                 "train/std", th.exp(self.policy.log_std).mean().item())
diff --git a/metastable_baselines/sac/sac.py b/metastable_baselines/sac/sac.py
index 607c008..6b6dd5d 100644
--- a/metastable_baselines/sac/sac.py
+++ b/metastable_baselines/sac/sac.py
@@ -374,6 +374,7 @@ class SAC(OffPolicyAlgorithm):
                            self._n_updates, exclude="tensorboard")
         self.logger.record("train/action_loss", np.mean(action_losses))
         self.logger.record("train/ent_coef", np.mean(ent_coefs))
+        self.logger.record("train/ssf", self.sde_sample_freq)
         self.logger.record("train/actor_loss", np.mean(actor_losses))
         self.logger.record("train/critic_loss", np.mean(critic_losses))
         if len(ent_coef_losses) > 0: