Fixed bug in RolloutBuffer when using parallel envs
parent 02e4ed1510
commit afec4e709c
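With more than one parallel env the buffer stores means and chols with shape (buffer_size, n_envs, ...), while the flat indices drawn in get() address buffer_size * n_envs samples, so both fields have to be swapped and flattened like the other tensors before minibatching; the overridden get() below does exactly that. A minimal sketch of the shape handling, assuming illustrative sizes and a simplified stand-in for the stable-baselines3 swap_and_flatten helper:

import numpy as np

buffer_size, n_envs, action_dim = 8, 4, 3  # illustrative sizes, not from this repo

# Layout used while collecting rollouts: one slot per step and per env.
means = np.zeros((buffer_size, n_envs, action_dim), dtype=np.float32)

def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
    # Merge the step and env axes so that flat minibatch indices in
    # [0, buffer_size * n_envs) are valid (simplified to 3-D arrays here).
    shape = arr.shape
    return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])

flat_means = swap_and_flatten(means)
assert flat_means.shape == (buffer_size * n_envs, action_dim)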
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Type, Union, NamedTuple
+from typing import Any, Dict, Optional, Type, Union, NamedTuple, Generator
 
 import numpy as np
 import torch as th
@@ -13,6 +13,9 @@ from stable_baselines3.common.utils import obs_as_tensor
 from ..misc.distTools import get_mean_and_chol
 from ..distributions.distributions import Strength, UniversalGaussianDistribution
 
+from stable_baselines3.common.vec_env import VecNormalize
+
+
 # TRL requires the original mean and covariance from the policy when the datapoint was created.
 # GaussianRolloutBuffer extends the RolloutBuffer by these two fields
 
@@ -54,7 +57,7 @@ class GaussianRolloutBuffer(RolloutBuffer):
 
     def reset(self) -> None:
         self.means = np.zeros(
-            (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
+            (self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
         self.chols = np.zeros(
             (self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32)
         super().reset()
@@ -98,12 +101,43 @@ class GaussianRolloutBuffer(RolloutBuffer):
         self.episode_starts[self.pos] = np.array(episode_start).copy()
         self.values[self.pos] = value.clone().cpu().numpy().flatten()
         self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
-        self.means[self.pos] = mean.clone().cpu().numpy()
-        self.chols[self.pos] = chol.clone().cpu().numpy()
+        self.means[self.pos] = np.array(mean).copy()
+        self.chols[self.pos] = np.array(chol).copy()
         self.pos += 1
         if self.pos == self.buffer_size:
             self.full = True
 
+    def get(self, batch_size: Optional[int] = None) -> Generator[GaussianRolloutBufferSamples, None, None]:
+        assert self.full, ""
+        indices = np.random.permutation(self.buffer_size * self.n_envs)
+        # Prepare the data
+        if not self.generator_ready:
+
+            _tensor_names = [
+                "observations",
+                "actions",
+                "values",
+                "log_probs",
+                "advantages",
+                "returns",
+                "means",
+                "chols"
+            ]
+
+            for tensor in _tensor_names:
+                self.__dict__[tensor] = self.swap_and_flatten(
+                    self.__dict__[tensor])
+            self.generator_ready = True
+
+        # Return everything, don't create minibatches
+        if batch_size is None:
+            batch_size = self.buffer_size * self.n_envs
+
+        start_idx = 0
+        while start_idx < self.buffer_size * self.n_envs:
+            yield self._get_samples(indices[start_idx: start_idx + batch_size])
+            start_idx += batch_size
+
     def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> GaussianRolloutBufferSamples:
         data = (
             self.observations[batch_inds],
@@ -183,6 +217,8 @@ class GaussianRolloutCollectorAuxclass():
                 dist = self.policy.get_distribution(obs_tensor).distribution
                 mean, chol = get_mean_and_chol(dist)
             actions = actions.cpu().numpy()
+            mean = mean.cpu().numpy()
+            chol = chol.cpu().numpy()
 
             # Rescale and perform action
             clipped_actions = actions
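For context, a hedged sketch of how the extended get() might be consumed during an update; rollout_buffer stands for a filled GaussianRolloutBuffer, the minibatch size is arbitrary, and the sample field names means/chols are assumed from the data tuple built in _get_samples:

# Hypothetical consumer; `rollout_buffer` is assumed to be a filled
# GaussianRolloutBuffer and 64 is an arbitrary minibatch size.
for rollout_data in rollout_buffer.get(batch_size=64):
    # Each minibatch carries the flattened old means and Cholesky factors
    # next to the usual rollout fields (field names assumed here).
    old_mean, old_chol = rollout_data.means, rollout_data.chols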