From b2384e183cb6806f4fef432873f3593281f19e31 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 9 Mar 2024 12:33:20 +0100 Subject: [PATCH] Add support for VecEnvs --- metastable_baselines2/common/buffers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/metastable_baselines2/common/buffers.py b/metastable_baselines2/common/buffers.py index 94e807c..18dbb41 100644 --- a/metastable_baselines2/common/buffers.py +++ b/metastable_baselines2/common/buffers.py @@ -202,6 +202,8 @@ class BetterRolloutBuffer(RolloutBuffer): "log_probs", "advantages", "returns", + "means", + "cov_decomps" ] for tensor in _tensor_names: @@ -227,8 +229,8 @@ class BetterRolloutBuffer(RolloutBuffer): self.actions[batch_inds], self.values[batch_inds].flatten(), self.log_probs[batch_inds].flatten(), - np.squeeze(self.means[batch_inds], axis=1), - np.squeeze(self.cov_decomps[batch_inds], axis=1), + self.means[batch_inds], + self.cov_decomps[batch_inds], self.advantages[batch_inds].flatten(), self.returns[batch_inds].flatten(), )