From 91f64c10d70ae9c9c896f88a706020a9e9760ed1 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Fri, 1 Jul 2022 20:02:09 +0200 Subject: [PATCH] Fixed bug with initialization of buffer --- metastable_baselines/misc/rollout_buffer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/metastable_baselines/misc/rollout_buffer.py b/metastable_baselines/misc/rollout_buffer.py index f21ee39..c8ed57a 100644 --- a/metastable_baselines/misc/rollout_buffer.py +++ b/metastable_baselines/misc/rollout_buffer.py @@ -34,17 +34,18 @@ class GaussianRolloutBuffer(RolloutBuffer): n_envs: int = 1, cov_shape=None, ): - - super().__init__(buffer_size, observation_space, action_space, - device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma) self.means, self.stds = None, None # TODO: Correct shape for full cov matrix # self.action_space.shape + self.action_space.shape if cov_shape == None: - cov_shape = self.action_space.shape + cov_shape = action_space.shape self.cov_shape = cov_shape + # It is ugly, but necessary to put this at the bottom of the init... + super().__init__(buffer_size, observation_space, action_space, + device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma) + def reset(self) -> None: self.means = np.zeros( (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)