diff --git a/metastable_baselines/misc/rollout_buffer.py b/metastable_baselines/misc/rollout_buffer.py index 49b5270..b4f6c6d 100644 --- a/metastable_baselines/misc/rollout_buffer.py +++ b/metastable_baselines/misc/rollout_buffer.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Type, Union, NamedTuple +from typing import Any, Dict, Optional, Type, Union, NamedTuple, Generator import numpy as np import torch as th @@ -13,6 +13,9 @@ from stable_baselines3.common.utils import obs_as_tensor from ..misc.distTools import get_mean_and_chol from ..distributions.distributions import Strength, UniversalGaussianDistribution +from stable_baselines3.common.vec_env import VecNormalize + + # TRL requires the origina mean and covariance from the policy when the datapoint was created. # GaussianRolloutBuffer extends the RolloutBuffer by these two fields @@ -54,7 +57,7 @@ class GaussianRolloutBuffer(RolloutBuffer): def reset(self) -> None: self.means = np.zeros( - (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32) + (self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32) self.chols = np.zeros( (self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32) super().reset() @@ -98,12 +101,43 @@ class GaussianRolloutBuffer(RolloutBuffer): self.episode_starts[self.pos] = np.array(episode_start).copy() self.values[self.pos] = value.clone().cpu().numpy().flatten() self.log_probs[self.pos] = log_prob.clone().cpu().numpy() - self.means[self.pos] = mean.clone().cpu().numpy() - self.chols[self.pos] = chol.clone().cpu().numpy() + self.means[self.pos] = np.array(mean).copy() + self.chols[self.pos] = np.array(chol).copy() self.pos += 1 if self.pos == self.buffer_size: self.full = True + def get(self, batch_size: Optional[int] = None) -> Generator[GaussianRolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.n_envs) + # Prepare the data + if not self.generator_ready: + + _tensor_names = [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + "means", + "chols" + ] + + for tensor in _tensor_names: + self.__dict__[tensor] = self.swap_and_flatten( + self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx: start_idx + batch_size]) + start_idx += batch_size + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> GaussianRolloutBufferSamples: data = ( self.observations[batch_inds], @@ -183,6 +217,8 @@ class GaussianRolloutCollectorAuxclass(): dist = self.policy.get_distribution(obs_tensor).distribution mean, chol = get_mean_and_chol(dist) actions = actions.cpu().numpy() + mean = mean.cpu().numpy() + chol = chol.cpu().numpy() # Rescale and perform action clipped_actions = actions