diff --git a/metastable_baselines/misc/rollout_buffer.py b/metastable_baselines/misc/rollout_buffer.py index fbd49f2..f21ee39 100644 --- a/metastable_baselines/misc/rollout_buffer.py +++ b/metastable_baselines/misc/rollout_buffer.py @@ -7,6 +7,9 @@ from gym import spaces from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.vec_env import VecNormalize +# TRL requires the original mean and covariance from the policy when the datapoint was created. +# GaussianRolloutBuffer extends the RolloutBuffer by these two fields + class GaussianRolloutBufferSamples(NamedTuple): observations: th.Tensor @@ -29,18 +32,24 @@ class GaussianRolloutBuffer(RolloutBuffer): gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, + cov_shape=None, ): super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma) self.means, self.stds = None, None + # TODO: Correct shape for full cov matrix + # self.action_space.shape + self.action_space.shape + + if cov_shape == None: + cov_shape = self.action_space.shape + self.cov_shape = cov_shape def reset(self) -> None: self.means = np.zeros( (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32) self.stds = np.zeros( - # (self.buffer_size, self.n_envs) + self.action_space.shape + self.action_space.shape, dtype=np.float32) - (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32) + (self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32) super().reset() def add(