diff --git a/metastable_baselines2/common/buffers.py b/metastable_baselines2/common/buffers.py index 94e807c..18dbb41 100644 --- a/metastable_baselines2/common/buffers.py +++ b/metastable_baselines2/common/buffers.py @@ -202,6 +202,8 @@ class BetterRolloutBuffer(RolloutBuffer): "log_probs", "advantages", "returns", + "means", + "cov_decomps" ] for tensor in _tensor_names: @@ -227,8 +229,8 @@ class BetterRolloutBuffer(RolloutBuffer): self.actions[batch_inds], self.values[batch_inds].flatten(), self.log_probs[batch_inds].flatten(), - np.squeeze(self.means[batch_inds], axis=1), - np.squeeze(self.cov_decomps[batch_inds], axis=1), + self.means[batch_inds], + self.cov_decomps[batch_inds], self.advantages[batch_inds].flatten(), self.returns[batch_inds].flatten(), )