Fixed bug in RolloutBuffer when using parallel envs
This commit is contained in:
parent
02e4ed1510
commit
afec4e709c
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, Optional, Type, Union, NamedTuple
|
from typing import Any, Dict, Optional, Type, Union, NamedTuple, Generator
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
@ -13,6 +13,9 @@ from stable_baselines3.common.utils import obs_as_tensor
|
|||||||
from ..misc.distTools import get_mean_and_chol
|
from ..misc.distTools import get_mean_and_chol
|
||||||
from ..distributions.distributions import Strength, UniversalGaussianDistribution
|
from ..distributions.distributions import Strength, UniversalGaussianDistribution
|
||||||
|
|
||||||
|
from stable_baselines3.common.vec_env import VecNormalize
|
||||||
|
|
||||||
|
|
||||||
# TRL requires the origina mean and covariance from the policy when the datapoint was created.
|
# TRL requires the origina mean and covariance from the policy when the datapoint was created.
|
||||||
# GaussianRolloutBuffer extends the RolloutBuffer by these two fields
|
# GaussianRolloutBuffer extends the RolloutBuffer by these two fields
|
||||||
|
|
||||||
@ -54,7 +57,7 @@ class GaussianRolloutBuffer(RolloutBuffer):
|
|||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.means = np.zeros(
|
self.means = np.zeros(
|
||||||
(self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
|
(self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
|
||||||
self.chols = np.zeros(
|
self.chols = np.zeros(
|
||||||
(self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32)
|
(self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32)
|
||||||
super().reset()
|
super().reset()
|
||||||
@ -98,12 +101,43 @@ class GaussianRolloutBuffer(RolloutBuffer):
|
|||||||
self.episode_starts[self.pos] = np.array(episode_start).copy()
|
self.episode_starts[self.pos] = np.array(episode_start).copy()
|
||||||
self.values[self.pos] = value.clone().cpu().numpy().flatten()
|
self.values[self.pos] = value.clone().cpu().numpy().flatten()
|
||||||
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
|
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
|
||||||
self.means[self.pos] = mean.clone().cpu().numpy()
|
self.means[self.pos] = np.array(mean).copy()
|
||||||
self.chols[self.pos] = chol.clone().cpu().numpy()
|
self.chols[self.pos] = np.array(chol).copy()
|
||||||
self.pos += 1
|
self.pos += 1
|
||||||
if self.pos == self.buffer_size:
|
if self.pos == self.buffer_size:
|
||||||
self.full = True
|
self.full = True
|
||||||
|
|
||||||
|
def get(self, batch_size: Optional[int] = None) -> Generator[GaussianRolloutBufferSamples, None, None]:
|
||||||
|
assert self.full, ""
|
||||||
|
indices = np.random.permutation(self.buffer_size * self.n_envs)
|
||||||
|
# Prepare the data
|
||||||
|
if not self.generator_ready:
|
||||||
|
|
||||||
|
_tensor_names = [
|
||||||
|
"observations",
|
||||||
|
"actions",
|
||||||
|
"values",
|
||||||
|
"log_probs",
|
||||||
|
"advantages",
|
||||||
|
"returns",
|
||||||
|
"means",
|
||||||
|
"chols"
|
||||||
|
]
|
||||||
|
|
||||||
|
for tensor in _tensor_names:
|
||||||
|
self.__dict__[tensor] = self.swap_and_flatten(
|
||||||
|
self.__dict__[tensor])
|
||||||
|
self.generator_ready = True
|
||||||
|
|
||||||
|
# Return everything, don't create minibatches
|
||||||
|
if batch_size is None:
|
||||||
|
batch_size = self.buffer_size * self.n_envs
|
||||||
|
|
||||||
|
start_idx = 0
|
||||||
|
while start_idx < self.buffer_size * self.n_envs:
|
||||||
|
yield self._get_samples(indices[start_idx: start_idx + batch_size])
|
||||||
|
start_idx += batch_size
|
||||||
|
|
||||||
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> GaussianRolloutBufferSamples:
|
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> GaussianRolloutBufferSamples:
|
||||||
data = (
|
data = (
|
||||||
self.observations[batch_inds],
|
self.observations[batch_inds],
|
||||||
@ -183,6 +217,8 @@ class GaussianRolloutCollectorAuxclass():
|
|||||||
dist = self.policy.get_distribution(obs_tensor).distribution
|
dist = self.policy.get_distribution(obs_tensor).distribution
|
||||||
mean, chol = get_mean_and_chol(dist)
|
mean, chol = get_mean_and_chol(dist)
|
||||||
actions = actions.cpu().numpy()
|
actions = actions.cpu().numpy()
|
||||||
|
mean = mean.cpu().numpy()
|
||||||
|
chol = chol.cpu().numpy()
|
||||||
|
|
||||||
# Rescale and perform action
|
# Rescale and perform action
|
||||||
clipped_actions = actions
|
clipped_actions = actions
|
||||||
|
Loading…
Reference in New Issue
Block a user