Fixed bug with initialization of buffer
This commit is contained in:
parent
0dc9edf112
commit
91f64c10d7
@ -34,17 +34,18 @@ class GaussianRolloutBuffer(RolloutBuffer):
|
||||
n_envs: int = 1,
|
||||
cov_shape=None,
|
||||
):
|
||||
|
||||
super().__init__(buffer_size, observation_space, action_space,
|
||||
device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma)
|
||||
self.means, self.stds = None, None
|
||||
# TODO: Correct shape for full cov matrix
|
||||
# self.action_space.shape + self.action_space.shape
|
||||
|
||||
if cov_shape == None:
|
||||
cov_shape = self.action_space.shape
|
||||
cov_shape = action_space.shape
|
||||
self.cov_shape = cov_shape
|
||||
|
||||
# It is ugly, but necessary to put this at the bottom of the init...
|
||||
super().__init__(buffer_size, observation_space, action_space,
|
||||
device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma)
|
||||
|
||||
def reset(self) -> None:
|
||||
self.means = np.zeros(
|
||||
(self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
|
||||
|
Loading…
Reference in New Issue
Block a user