From f184b88f1996bb556f760faa2054aa4f34e7eb6d Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Fri, 15 Jul 2022 18:46:42 +0200 Subject: [PATCH] Allow std logging for full and diagonal cov policies --- metastable_baselines/ppo/ppo.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py index a5eacfd..69640f8 100644 --- a/metastable_baselines/ppo/ppo.py +++ b/metastable_baselines/ppo/ppo.py @@ -361,8 +361,12 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): self.logger.record( "train/std", th.exp(self.policy.log_std).mean().item()) if hasattr(self.policy, "chol"): - self.logger.record( - "train/std", th.mean(th.diagonal(self.policy.chol, dim1=-2, dim2=-1)).mean().item()) + if len(self.policy.chol.shape) == 1: + self.logger.record( + "train/std", th.mean(self.policy.chol).mean().item()) + else: + self.logger.record( + "train/std", th.mean(th.diagonal(self.policy.chol, dim1=-2, dim2=-1)).mean().item()) self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")