diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py
index f8008a6..7ce1961 100644
--- a/metastable_baselines/ppo/ppo.py
+++ b/metastable_baselines/ppo/ppo.py
@@ -360,13 +360,18 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         if hasattr(self.policy, "log_std"):
             self.logger.record(
                 "train/std", th.exp(self.policy.log_std).mean().item())
-        if hasattr(self.policy, "chol"):
+        elif hasattr(self.policy, "chol"):
             if len(self.policy.chol.shape) == 1:
                 self.logger.record(
                     "train/std", th.mean(self.policy.chol).mean().item())
             else:
+                if len(self.policy.chol.shape) == 2:
+                    chol = self.policy.chol
+                else:
+                    # TODO: Maybe use a broader sample?
+                    chol = self.policy.chol[0]
                 self.logger.record(
-                    "train/std", th.mean(th.sqrt(th.diagonal(self.policy.chol.T @ self.policy.chol, dim1=-2, dim2=-1))).mean().item())
+                    "train/std", th.mean(th.sqrt(th.diagonal(chol.T @ chol, dim1=-2, dim2=-1))).mean().item())
 
         self.logger.record("train/n_updates", self._n_updates,
                            exclude="tensorboard")