diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py index a5eacfd..69640f8 100644 --- a/metastable_baselines/ppo/ppo.py +++ b/metastable_baselines/ppo/ppo.py @@ -361,8 +361,12 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): self.logger.record( "train/std", th.exp(self.policy.log_std).mean().item()) if hasattr(self.policy, "chol"): - self.logger.record( - "train/std", th.mean(th.diagonal(self.policy.chol, dim1=-2, dim2=-1)).mean().item()) + if len(self.policy.chol.shape) == 1: + self.logger.record( + "train/std", th.mean(self.policy.chol).mean().item()) + else: + self.logger.record( + "train/std", th.mean(th.diagonal(self.policy.chol, dim1=-2, dim2=-1)).mean().item()) self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")