Calculate episodic infos for some fancy envs
This commit is contained in:
		
							parent
							
								
									f37c8caaa4
								
							
						
					
					
						commit
						6f1837bda5
					
				| @ -6,7 +6,7 @@ from gym import spaces | ||||
| 
 | ||||
| from stable_baselines3.common.buffers import RolloutBuffer | ||||
| from stable_baselines3.common.vec_env import VecNormalize | ||||
| from stable_baselines3.common.vec_env import VecEnv | ||||
| from stable_baselines3.common.vec_env import VecEnv, DummyVecEnv | ||||
| from stable_baselines3.common.callbacks import BaseCallback | ||||
| from stable_baselines3.common.utils import obs_as_tensor | ||||
| 
 | ||||
| @ -228,6 +228,10 @@ class GaussianRolloutCollectorAuxclass(): | ||||
|                     actions, self.action_space.low, self.action_space.high) | ||||
| 
 | ||||
|             new_obs, rewards, dones, infos = env.step(clipped_actions) | ||||
|             if 'episode_end' in infos[0]: | ||||
|                 for i in range(len(infos)): | ||||
|                     if infos[i]['episode_end'] and 'episode' not in infos: | ||||
|                         infos[i]['episode'] = {'r': rewards[i]} | ||||
|             if len(infos) and 'r' not in infos[0]: | ||||
|                 for i in range(len(infos)): | ||||
|                     if 'r' not in infos[i]: | ||||
| @ -286,8 +290,6 @@ class GaussianRolloutCollectorAuxclass(): | ||||
|         :param infos: List of additional information about the transition. | ||||
|         :param dones: Termination signals | ||||
|         """ | ||||
|         import pdb | ||||
|         pdb.set_trace() | ||||
|         if dones is None: | ||||
|             dones = np.array([False] * len(infos)) | ||||
|         for idx, info in enumerate(infos): | ||||
|  | ||||
| @ -399,6 +399,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): | ||||
|         self.logger.record("train/clip_fraction", np.mean(clip_fractions)) | ||||
|         self.logger.record("train/loss", loss.item()) | ||||
|         self.logger.record("train/explained_variance", explained_var) | ||||
|         self.logger.record("train/ssf", self.sde_sample_freq) | ||||
|         if hasattr(self.policy, "log_std"): | ||||
|             self.logger.record( | ||||
|                 "train/std", th.exp(self.policy.log_std).mean().item()) | ||||
|  | ||||
| @ -374,6 +374,7 @@ class SAC(OffPolicyAlgorithm): | ||||
|                            self._n_updates, exclude="tensorboard") | ||||
|         self.logger.record("train/action_loss", np.mean(action_losses)) | ||||
|         self.logger.record("train/ent_coef", np.mean(ent_coefs)) | ||||
|         self.logger.record("train/ssf", self.sde_sample_freq) | ||||
|         self.logger.record("train/actor_loss", np.mean(actor_losses)) | ||||
|         self.logger.record("train/critic_loss", np.mean(critic_losses)) | ||||
|         if len(ent_coef_losses) > 0: | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user