diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py index 69640f8..f8008a6 100644 --- a/metastable_baselines/ppo/ppo.py +++ b/metastable_baselines/ppo/ppo.py @@ -366,7 +366,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): "train/std", th.mean(self.policy.chol).mean().item()) else: self.logger.record( - "train/std", th.mean(th.diagonal(self.policy.chol, dim1=-2, dim2=-1)).mean().item()) + "train/std", th.mean(th.sqrt(th.diagonal(self.policy.chol.T @ self.policy.chol, dim1=-2, dim2=-1))).mean().item()) self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")