Fixed bug in RolloutBuffer when using parallel envs
This commit is contained in:
		
							parent
							
								
									02e4ed1510
								
							
						
					
					
						commit
						afec4e709c
					
				| @ -1,4 +1,4 @@ | |||||||
| from typing import Any, Dict, Optional, Type, Union, NamedTuple | from typing import Any, Dict, Optional, Type, Union, NamedTuple, Generator | ||||||
| 
 | 
 | ||||||
| import numpy as np | import numpy as np | ||||||
| import torch as th | import torch as th | ||||||
| @ -13,6 +13,9 @@ from stable_baselines3.common.utils import obs_as_tensor | |||||||
| from ..misc.distTools import get_mean_and_chol | from ..misc.distTools import get_mean_and_chol | ||||||
| from ..distributions.distributions import Strength, UniversalGaussianDistribution | from ..distributions.distributions import Strength, UniversalGaussianDistribution | ||||||
| 
 | 
 | ||||||
|  | from stable_baselines3.common.vec_env import VecNormalize | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| # TRL requires the origina mean and covariance from the policy when the datapoint was created. | # TRL requires the origina mean and covariance from the policy when the datapoint was created. | ||||||
| # GaussianRolloutBuffer extends the RolloutBuffer by these two fields | # GaussianRolloutBuffer extends the RolloutBuffer by these two fields | ||||||
| 
 | 
 | ||||||
| @ -54,7 +57,7 @@ class GaussianRolloutBuffer(RolloutBuffer): | |||||||
| 
 | 
 | ||||||
|     def reset(self) -> None: |     def reset(self) -> None: | ||||||
|         self.means = np.zeros( |         self.means = np.zeros( | ||||||
|             (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32) |             (self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32) | ||||||
|         self.chols = np.zeros( |         self.chols = np.zeros( | ||||||
|             (self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32) |             (self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32) | ||||||
|         super().reset() |         super().reset() | ||||||
| @ -98,12 +101,43 @@ class GaussianRolloutBuffer(RolloutBuffer): | |||||||
|         self.episode_starts[self.pos] = np.array(episode_start).copy() |         self.episode_starts[self.pos] = np.array(episode_start).copy() | ||||||
|         self.values[self.pos] = value.clone().cpu().numpy().flatten() |         self.values[self.pos] = value.clone().cpu().numpy().flatten() | ||||||
|         self.log_probs[self.pos] = log_prob.clone().cpu().numpy() |         self.log_probs[self.pos] = log_prob.clone().cpu().numpy() | ||||||
|         self.means[self.pos] = mean.clone().cpu().numpy() |         self.means[self.pos] = np.array(mean).copy() | ||||||
|         self.chols[self.pos] = chol.clone().cpu().numpy() |         self.chols[self.pos] = np.array(chol).copy() | ||||||
|         self.pos += 1 |         self.pos += 1 | ||||||
|         if self.pos == self.buffer_size: |         if self.pos == self.buffer_size: | ||||||
|             self.full = True |             self.full = True | ||||||
| 
 | 
 | ||||||
|  |     def get(self, batch_size: Optional[int] = None) -> Generator[GaussianRolloutBufferSamples, None, None]: | ||||||
|  |         assert self.full, "" | ||||||
|  |         indices = np.random.permutation(self.buffer_size * self.n_envs) | ||||||
|  |         # Prepare the data | ||||||
|  |         if not self.generator_ready: | ||||||
|  | 
 | ||||||
|  |             _tensor_names = [ | ||||||
|  |                 "observations", | ||||||
|  |                 "actions", | ||||||
|  |                 "values", | ||||||
|  |                 "log_probs", | ||||||
|  |                 "advantages", | ||||||
|  |                 "returns", | ||||||
|  |                 "means", | ||||||
|  |                 "chols" | ||||||
|  |             ] | ||||||
|  | 
 | ||||||
|  |             for tensor in _tensor_names: | ||||||
|  |                 self.__dict__[tensor] = self.swap_and_flatten( | ||||||
|  |                     self.__dict__[tensor]) | ||||||
|  |             self.generator_ready = True | ||||||
|  | 
 | ||||||
|  |         # Return everything, don't create minibatches | ||||||
|  |         if batch_size is None: | ||||||
|  |             batch_size = self.buffer_size * self.n_envs | ||||||
|  | 
 | ||||||
|  |         start_idx = 0 | ||||||
|  |         while start_idx < self.buffer_size * self.n_envs: | ||||||
|  |             yield self._get_samples(indices[start_idx: start_idx + batch_size]) | ||||||
|  |             start_idx += batch_size | ||||||
|  | 
 | ||||||
|     def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> GaussianRolloutBufferSamples: |     def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> GaussianRolloutBufferSamples: | ||||||
|         data = ( |         data = ( | ||||||
|             self.observations[batch_inds], |             self.observations[batch_inds], | ||||||
| @ -183,6 +217,8 @@ class GaussianRolloutCollectorAuxclass(): | |||||||
|                 dist = self.policy.get_distribution(obs_tensor).distribution |                 dist = self.policy.get_distribution(obs_tensor).distribution | ||||||
|                 mean, chol = get_mean_and_chol(dist) |                 mean, chol = get_mean_and_chol(dist) | ||||||
|             actions = actions.cpu().numpy() |             actions = actions.cpu().numpy() | ||||||
|  |             mean = mean.cpu().numpy() | ||||||
|  |             chol = chol.cpu().numpy() | ||||||
| 
 | 
 | ||||||
|             # Rescale and perform action |             # Rescale and perform action | ||||||
|             clipped_actions = actions |             clipped_actions = actions | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user