diff --git a/metastable_baselines/misc/rollout_buffer.py b/metastable_baselines/misc/rollout_buffer.py index f21ee39..c8ed57a 100644 --- a/metastable_baselines/misc/rollout_buffer.py +++ b/metastable_baselines/misc/rollout_buffer.py @@ -34,17 +34,18 @@ class GaussianRolloutBuffer(RolloutBuffer): n_envs: int = 1, cov_shape=None, ): - - super().__init__(buffer_size, observation_space, action_space, - device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma) self.means, self.stds = None, None # TODO: Correct shape for full cov matrix # self.action_space.shape + self.action_space.shape if cov_shape == None: - cov_shape = self.action_space.shape + cov_shape = action_space.shape self.cov_shape = cov_shape + # It is ugly, but necessary to put this at the bottom of the init... + super().__init__(buffer_size, observation_space, action_space, + device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma) + def reset(self) -> None: self.means = np.zeros( (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)