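"""Rollout buffers that extend Stable-Baselines3's ``RolloutBuffer`` and
``DictRolloutBuffer`` to also store, for every transition, the mean and the
covariance decomposition of the policy's action distribution, alongside the
usual observations, actions, values, log probabilities, advantages and returns."""
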
import warnings

from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Tuple, Union, NamedTuple

import numpy as np
import torch as th
from gymnasium import spaces

from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.type_aliases import (
    DictRolloutBufferSamples,
    RolloutBufferSamples,
    TensorDict,
)
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize

try:
    # Check memory used by replay buffer when possible
    import psutil
except ImportError:
    psutil = None

from stable_baselines3.common.buffers import RolloutBuffer, DictRolloutBuffer
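

# Sample tuples mirroring RolloutBufferSamples / DictRolloutBufferSamples, extended
# with the stored distribution mean and covariance decomposition for each action.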
class BetterRolloutBufferSamples(NamedTuple):
    observations: th.Tensor
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    mean: th.Tensor
    cov_decomp: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor


class BetterDictRolloutBufferSamples(NamedTuple):
    observations: TensorDict
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    mean: th.Tensor
    cov_decomp: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor


class BetterRolloutBuffer(RolloutBuffer):
    """
    Extended to also save the mean and cov decomp.

    Rollout buffer used in on-policy algorithms like A2C/PPO.
    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use the PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    :param full_cov: Whether to store a full (action_dim x action_dim) covariance
        decomposition per step instead of a diagonal one.
    """

    observations: np.ndarray
    actions: np.ndarray
    rewards: np.ndarray
    advantages: np.ndarray
    returns: np.ndarray
    episode_starts: np.ndarray
    log_probs: np.ndarray
    values: np.ndarray

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
        full_cov: bool = False,
    ):
        # full_cov must be set before super().__init__(), which calls self.reset()
        self.full_cov = full_cov
        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
        self.gae_lambda = gae_lambda
        self.gamma = gamma
        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
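        # Covariance decomposition storage: one (action_dim, action_dim) matrix per
        # step when full_cov is set, otherwise a diagonal stored as (action_dim,) values.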
        if self.full_cov:
            self.cov_decomps = np.zeros(
                (self.buffer_size, self.n_envs, self.action_dim, self.action_dim), dtype=np.float32
            )
        else:
            self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.generator_ready = False
        super().reset()

    def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
        """
        Post-processing step: compute the lambda-return (TD(lambda) estimate)
        and GAE(lambda) advantage.

        Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
        to compute the advantage. To obtain the Monte-Carlo advantage estimate (A(s) = R - V(S)),
        where R is the sum of discounted rewards with value bootstrap
        (because we don't always have a full episode), set ``gae_lambda=1.0`` during initialization.

        The TD(lambda) estimator also has two special cases:
        - TD(1) is the Monte-Carlo estimate (sum of discounted rewards)
        - TD(0) is the one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))

        For more information, see the discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.

        :param last_values: state value estimation for the last step (one for each env)
        :param dones: if the last step was a terminal step (one bool for each env).
        """
        # Convert to numpy
        last_values = last_values.clone().cpu().numpy().flatten()

        last_gae_lam = 0
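        # Backward GAE recursion:
        #   delta_t = r_t + gamma * V(s_{t+1}) * next_non_terminal - V(s_t)
        #   A_t     = delta_t + gamma * gae_lambda * next_non_terminal * A_{t+1}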
        for step in reversed(range(self.buffer_size)):
            if step == self.buffer_size - 1:
                next_non_terminal = 1.0 - dones
                next_values = last_values
            else:
                next_non_terminal = 1.0 - self.episode_starts[step + 1]
                next_values = self.values[step + 1]
            delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam
        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
        self.returns = self.advantages + self.values

    def add(
        self,
        obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
        mean: th.Tensor,
        cov_decomp: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward: Reward for the transition
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        :param mean: mean of the action distribution for the current state.
        :param cov_decomp: covariance decomposition of the action distribution
            for the current state.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs, *self.obs_shape))

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        self.observations[self.pos] = np.array(obs)
        self.actions[self.pos] = np.array(action)
        self.rewards[self.pos] = np.array(reward)
        self.episode_starts[self.pos] = np.array(episode_start)
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.means[self.pos] = mean.clone().cpu().numpy()
        self.cov_decomps[self.pos] = cov_decomp.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size: Optional[int] = None) -> Generator[BetterRolloutBufferSamples, None, None]:
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            _tensor_names = [
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
                "means",
                "cov_decomps",
            ]
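
            # swap_and_flatten merges the time and env axes:
            # (buffer_size, n_envs, *shape) -> (buffer_size * n_envs, *shape)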
            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> BetterRolloutBufferSamples:
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.means[batch_inds],
            self.cov_decomps[batch_inds],
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
        )
        return BetterRolloutBufferSamples(*tuple(map(self.to_torch, data)))


class BetterDictRolloutBuffer(DictRolloutBuffer):
    """
    Extended to also save the mean and cov decomp.

    Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
    Extends the RolloutBuffer to use dictionary observations.

    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use the PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to Monte-Carlo advantage estimate when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    :param full_cov: Whether to store a full (action_dim x action_dim) covariance
        decomposition per step instead of a diagonal one (mirrors ``BetterRolloutBuffer``).
    """

    observation_space: spaces.Dict
    obs_shape: Dict[str, Tuple[int, ...]]  # type: ignore[assignment]
    observations: Dict[str, np.ndarray]  # type: ignore[assignment]

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Dict,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
        full_cov: bool = False,
    ):
        # Used by reset() below to size the covariance decomposition storage
        self.full_cov = full_cov
        super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"

        self.gae_lambda = gae_lambda
        self.gamma = gamma

        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        self.observations = {}
        for key, obs_input_shape in self.obs_shape.items():
            self.observations[key] = np.zeros((self.buffer_size, self.n_envs, *obs_input_shape), dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # Also allocate storage for the distribution mean and covariance decomposition,
        # which _get_samples() expects (mirrors BetterRolloutBuffer.reset)
        self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        if self.full_cov:
            self.cov_decomps = np.zeros(
                (self.buffer_size, self.n_envs, self.action_dim, self.action_dim), dtype=np.float32
            )
        else:
            self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.generator_ready = False
        super(RolloutBuffer, self).reset()

    def add(  # type: ignore[override]
        self,
        obs: Dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
        mean: th.Tensor,
        cov_decomp: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward: Reward for the transition
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        :param mean: mean of the action distribution for the current state.
        :param cov_decomp: covariance decomposition of the action distribution
            for the current state.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        for key in self.observations.keys():
            obs_ = np.array(obs[key])
            # Reshape needed when using multiple envs with discrete observations
            # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
                obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
            self.observations[key][self.pos] = obs_

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        self.actions[self.pos] = np.array(action)
        self.rewards[self.pos] = np.array(reward)
        self.episode_starts[self.pos] = np.array(episode_start)
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.means[self.pos] = mean.clone().cpu().numpy()
        self.cov_decomps[self.pos] = cov_decomp.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(  # type: ignore[override]
        self,
        batch_size: Optional[int] = None,
    ) -> Generator[BetterDictRolloutBufferSamples, None, None]:
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            for key, obs in self.observations.items():
                self.observations[key] = self.swap_and_flatten(obs)

            _tensor_names = ["actions", "values", "log_probs", "advantages", "returns", "means", "cov_decomps"]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(  # type: ignore[override]
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> BetterDictRolloutBufferSamples:
        return BetterDictRolloutBufferSamples(
            observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
            actions=self.to_torch(self.actions[batch_inds]),
            old_values=self.to_torch(self.values[batch_inds].flatten()),
            old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
            # Keep the (batch, action_dim[, action_dim]) shape of the mean and
            # covariance decomposition, as in BetterRolloutBuffer._get_samples
            mean=self.to_torch(self.means[batch_inds]),
            cov_decomp=self.to_torch(self.cov_decomps[batch_inds]),
            advantages=self.to_torch(self.advantages[batch_inds].flatten()),
            returns=self.to_torch(self.returns[batch_inds].flatten()),
        )
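

# ------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module): fills a
# BetterRolloutBuffer with dummy transitions, assuming Box observation/action spaces
# and a diagonal covariance decomposition (full_cov=False), then samples mini-batches.
if __name__ == "__main__":
    n_envs, buffer_size, obs_dim, act_dim = 2, 4, 3, 2
    obs_space = spaces.Box(low=-1.0, high=1.0, shape=(obs_dim,), dtype=np.float32)
    act_space = spaces.Box(low=-1.0, high=1.0, shape=(act_dim,), dtype=np.float32)

    buffer = BetterRolloutBuffer(buffer_size, obs_space, act_space, device="cpu", n_envs=n_envs)
    for _ in range(buffer_size):
        buffer.add(
            obs=np.zeros((n_envs, obs_dim), dtype=np.float32),
            action=np.zeros((n_envs, act_dim), dtype=np.float32),
            reward=np.ones(n_envs, dtype=np.float32),
            episode_start=np.zeros(n_envs, dtype=np.float32),
            value=th.zeros(n_envs),
            log_prob=th.zeros(n_envs),
            mean=th.zeros(n_envs, act_dim),
            cov_decomp=th.ones(n_envs, act_dim),  # diagonal decomposition: one value per action dim
        )
    buffer.compute_returns_and_advantage(last_values=th.zeros(n_envs), dones=np.zeros(n_envs))
    for batch in buffer.get(batch_size=4):
        # Each mini-batch keeps the per-action structure of mean and cov_decomp
        print(batch.mean.shape, batch.cov_decomp.shape)  # torch.Size([4, 2]) torch.Size([4, 2])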