From 3fa6de7e66f4cc76062d1c5748083e27ac811922 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 16 Jul 2022 15:28:16 +0200 Subject: [PATCH] Broader sampling of stds for logging with batched full covs --- metastable_baselines/ppo/ppo.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py index 7ce1961..54672f4 100644 --- a/metastable_baselines/ppo/ppo.py +++ b/metastable_baselines/ppo/ppo.py @@ -361,17 +361,16 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): self.logger.record( "train/std", th.exp(self.policy.log_std).mean().item()) elif hasattr(self.policy, "chol"): - if len(self.policy.chol.shape) == 1: + chol = self.policy.chol + if len(chol.shape) == 1: self.logger.record( - "train/std", th.mean(self.policy.chol).mean().item()) - else: - if len(self.policy.chol.shape) == 2: - chol = self.policy.chol - else: - # TODO: Maybe use a broader sample? - chol = self.policy.chol[0] + "train/std", th.mean(chol).mean().item()) + elif len(chol.shape) == 2: self.logger.record( "train/std", th.mean(th.sqrt(th.diagonal(chol.T @ chol, dim1=-2, dim2=-1))).mean().item()) + else: + self.logger.record( + "train/std", th.mean(th.sqrt(th.diagonal(chol.mT @ chol, dim1=-2, dim2=-1))).mean().item()) self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")