Implement Buffer that stores Distribution params
parent 5ed5d32083
commit 1fa66611a3

metastable_baselines2/common/buffers.py (new file, 386 lines)
@@ -0,0 +1,386 @@
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Tuple, Union, NamedTuple

import numpy as np
import torch as th
from gymnasium import spaces

from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.type_aliases import (
    DictRolloutBufferSamples,
    RolloutBufferSamples,
    TensorDict
)
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize

try:
    # Check memory used by replay buffer when possible
    import psutil
except ImportError:
    psutil = None

from stable_baselines3.common.buffers import RolloutBuffer, DictRolloutBuffer

class BetterRolloutBufferSamples(NamedTuple):
    observations: th.Tensor
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    mean: th.Tensor
    cov_decomp: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor


class BetterDictRolloutBufferSamples(NamedTuple):
    observations: TensorDict
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    mean: th.Tensor
    cov_decomp: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor

class BetterRolloutBuffer(RolloutBuffer):
    """
    Extended to also save the mean and cov decomp.

    Rollout buffer used in on-policy algorithms like A2C/PPO.
    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    observations: np.ndarray
    actions: np.ndarray
    rewards: np.ndarray
    advantages: np.ndarray
    returns: np.ndarray
    episode_starts: np.ndarray
    log_probs: np.ndarray
    values: np.ndarray

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
        self.gae_lambda = gae_lambda
        self.gamma = gamma
        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # Distribution parameters: one entry per action dimension
        # (assumes a per-dimension decomposition such as the std vector of a diagonal Gaussian)
        self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.generator_ready = False
        super().reset()

    def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
        """
        Post-processing step: compute the lambda-return (TD(lambda) estimate)
        and GAE(lambda) advantage.

        Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
        to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S))
        where R is the sum of discounted reward with value bootstrap
        (because we don't always have full episode), set ``gae_lambda=1.0`` during initialization.

        The TD(lambda) estimator has also two special cases:
        - TD(1) is Monte-Carlo estimate (sum of discounted rewards)
        - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))

        For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.

        :param last_values: state value estimation for the last step (one for each env)
        :param dones: if the last step was a terminal step (one bool for each env).
        """
        # Convert to numpy
        last_values = last_values.clone().cpu().numpy().flatten()

        last_gae_lam = 0
        for step in reversed(range(self.buffer_size)):
            if step == self.buffer_size - 1:
                next_non_terminal = 1.0 - dones
                next_values = last_values
            else:
                next_non_terminal = 1.0 - self.episode_starts[step + 1]
                next_values = self.values[step + 1]
            delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam
        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
        self.returns = self.advantages + self.values

    def add(
        self,
        obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
        mean: th.Tensor,
        cov_decomp: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward:
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        :param mean: mean of the action distribution.
        :param cov_decomp: decomposition of the covariance of the action distribution.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs, *self.obs_shape))

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        self.observations[self.pos] = np.array(obs)
        self.actions[self.pos] = np.array(action)
        self.rewards[self.pos] = np.array(reward)
        self.episode_starts[self.pos] = np.array(episode_start)
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.means[self.pos] = mean.clone().cpu().numpy()
        self.cov_decomps[self.pos] = cov_decomp.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            _tensor_names = [
                "observations",
                "actions",
                "values",
                "log_probs",
                "means",
                "cov_decomps",
                "advantages",
                "returns",
            ]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> BetterRolloutBufferSamples:
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.means[batch_inds],
            self.cov_decomps[batch_inds],
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
        )
        return BetterRolloutBufferSamples(*tuple(map(self.to_torch, data)))


class BetterDictRolloutBuffer(DictRolloutBuffer):
    """
    Extended to also save the mean and cov decomp.

    Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
    Extends the RolloutBuffer to use dictionary observations

    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to Monte-Carlo advantage estimate when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    observation_space: spaces.Dict
    obs_shape: Dict[str, Tuple[int, ...]]  # type: ignore[assignment]
    observations: Dict[str, np.ndarray]  # type: ignore[assignment]

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Dict,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"

        self.gae_lambda = gae_lambda
        self.gamma = gamma

        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        self.observations = {}
        for key, obs_input_shape in self.obs_shape.items():
            self.observations[key] = np.zeros((self.buffer_size, self.n_envs, *obs_input_shape), dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # Distribution parameters, mirroring BetterRolloutBuffer (required by _get_samples below)
        self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.generator_ready = False
        super(RolloutBuffer, self).reset()

    def add(  # type: ignore[override]
        self,
        obs: Dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
        mean: th.Tensor,
        cov_decomp: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward:
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        :param mean: mean of the action distribution.
        :param cov_decomp: decomposition of the covariance of the action distribution.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        for key in self.observations.keys():
            obs_ = np.array(obs[key])
            # Reshape needed when using multiple envs with discrete observations
            # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
                obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
            self.observations[key][self.pos] = obs_

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        self.actions[self.pos] = np.array(action)
        self.rewards[self.pos] = np.array(reward)
        self.episode_starts[self.pos] = np.array(episode_start)
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.means[self.pos] = mean.clone().cpu().numpy()
        self.cov_decomps[self.pos] = cov_decomp.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(  # type: ignore[override]
        self,
        batch_size: Optional[int] = None,
    ) -> Generator[DictRolloutBufferSamples, None, None]:
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            for key, obs in self.observations.items():
                self.observations[key] = self.swap_and_flatten(obs)

            _tensor_names = ["actions", "values", "log_probs", "means", "cov_decomps", "advantages", "returns"]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(  # type: ignore[override]
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> BetterDictRolloutBufferSamples:
        return BetterDictRolloutBufferSamples(
            observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
            actions=self.to_torch(self.actions[batch_inds]),
            old_values=self.to_torch(self.values[batch_inds].flatten()),
            old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
            mean=self.to_torch(self.means[batch_inds]),
            cov_decomp=self.to_torch(self.cov_decomps[batch_inds]),
            advantages=self.to_torch(self.advantages[batch_inds].flatten()),
            returns=self.to_torch(self.returns[batch_inds].flatten()),
        )
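
A minimal usage sketch of the new buffer (not part of the commit): the buffer is driven like stable-baselines3's RolloutBuffer, except that add() also takes the mean and covariance decomposition of the action distribution and get() hands them back with every minibatch. The dummy data, the sizes, and the diagonal-Gaussian std standing in for cov_decomp are assumptions for illustration only.

import numpy as np
import torch as th
from gymnasium import spaces
from metastable_baselines2.common.buffers import BetterRolloutBuffer

n_envs, n_steps, act_dim = 2, 8, 3
obs_space = spaces.Box(low=-1.0, high=1.0, shape=(4,))
act_space = spaces.Box(low=-1.0, high=1.0, shape=(act_dim,))
buffer = BetterRolloutBuffer(n_steps, obs_space, act_space, device="cpu", gae_lambda=0.95, gamma=0.99, n_envs=n_envs)

for _ in range(n_steps):
    # Dummy transition; in practice obs/rewards come from the vectorized env and the policy forward pass
    obs = np.random.randn(n_envs, 4).astype(np.float32)
    mean = th.zeros(n_envs, act_dim)   # mean of the action distribution
    std = th.ones(n_envs, act_dim)     # covariance decomposition (here: std of a diagonal Gaussian)
    dist = th.distributions.Normal(mean, std)
    actions = dist.sample()
    log_prob = dist.log_prob(actions).sum(dim=1)
    values = th.zeros(n_envs, 1)
    rewards = np.zeros(n_envs, dtype=np.float32)
    episode_starts = np.zeros(n_envs, dtype=np.float32)
    buffer.add(obs, actions.numpy(), rewards, episode_starts, values, log_prob, mean, std)

buffer.compute_returns_and_advantage(last_values=th.zeros(n_envs, 1), dones=np.zeros(n_envs))
for rollout_data in buffer.get(batch_size=4):
    # rollout_data.mean / rollout_data.cov_decomp hold the behaviour policy's parameters for these samples
    print(rollout_data.mean.shape, rollout_data.cov_decomp.shape)
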
@@ -8,6 +8,7 @@ from gymnasium import spaces
 from stable_baselines3.common.base_class import BaseAlgorithm
 from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
+from .buffers import BetterDictRolloutBuffer, BetterRolloutBuffer
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.policies import ActorCriticPolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@@ -70,6 +71,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
         use_sde: bool,
         sde_sample_freq: int,
         use_pca: bool,
+        pca_is: bool,
         rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
         rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
         stats_window_size: int = 100,
@@ -85,6 +87,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
         assert not (use_sde and use_pca)
         self.use_pca = use_pca
 
+        assert not pca_is or use_pca
+        self.pca_is = pca_is
+
         assert not rollout_buffer_class and not rollout_buffer_kwargs
 
         super().__init__(
@@ -122,9 +127,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
 
         if self.rollout_buffer_class is None:
             if isinstance(self.observation_space, spaces.Dict):
-                self.rollout_buffer_class = DictRolloutBuffer
+                self.rollout_buffer_class = BetterDictRolloutBuffer
             else:
-                self.rollout_buffer_class = RolloutBuffer
+                self.rollout_buffer_class = BetterRolloutBuffer
 
         self.rollout_buffer = self.rollout_buffer_class(
             self.n_steps,
@@ -186,7 +191,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
             with th.no_grad():
                 # Convert to pytorch tensor or to TensorDict
                 obs_tensor = obs_as_tensor(self._last_obs, self.device)
-                actions, values, log_probs = self.policy(obs_tensor)
+                actions, values, log_probs, distributions = self.policy(obs_tensor, conditioned_log_probs=self.pca_is)
             actions = actions.cpu().numpy()
 
             # Rescale and perform action
@@ -231,7 +236,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
                         terminal_value = self.policy.predict_values(terminal_obs)[0]
                     rewards[idx] += self.gamma * terminal_value
 
-            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)
+            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.mean, distributions.scale)
             self._last_obs = new_obs
             self._last_episode_starts = dones
 
@@ -254,6 +259,30 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
         """
         raise NotImplementedError
 
+    def predict(
+        self,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
+        state: Optional[Tuple[np.ndarray, ...]] = None,
+        episode_start: Optional[np.ndarray] = None,
+        deterministic: bool = False,
+        trajectory: th.Tensor = None,
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+        """
+        Get the policy action from an observation (and optional hidden state).
+        Includes sugar-coating to handle different observations (e.g. normalizing images).
+
+        :param observation: the input observation
+        :param state: The last hidden states (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
+            this corresponds to the beginning of episodes,
+            where the hidden states of the RNN must be reset.
+        :param deterministic: Whether or not to return deterministic actions.
+        :param trajectory: Past trajectory. Only required when using PCA.
+        :return: the model's action and the next hidden state
+            (used in recurrent policies)
+        """
+        return self.policy.predict(observation, state, episode_start, deterministic, trajectory=trajectory)
+
     def learn(
         self: SelfOnPolicyAlgorithm,
         total_timesteps: int,
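
The net effect of the rollout changes above is that each stored transition now carries the parameters of the distribution it was sampled from, not just its log-probability. A short, purely illustrative sketch of what that enables during the update, assuming the stored mean/cov_decomp are the mean and std of a diagonal Gaussian (in the commit itself this reconstruction is delegated to the projection layer):

import torch as th

def old_distribution_from_rollout(rollout_data):
    # Rebuild the behaviour policy's Gaussian from the stored parameters
    # (mean and cov_decomp are interpreted here as mean and std of a diagonal Gaussian).
    return th.distributions.Normal(rollout_data.mean, rollout_data.cov_decomp)

def mean_kl_to_current(rollout_data, current_dist):
    # Average KL(old || new) over the minibatch: the drift a trust-region method wants to bound.
    old_dist = old_distribution_from_rollout(rollout_data)
    return th.distributions.kl_divergence(old_dist, current_dist).sum(dim=-1).mean()
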
@@ -34,9 +34,12 @@ from stable_baselines3.common.torch_layers import (
 
 from stable_baselines3.common.policies import ContinuousCritic
 
-from stable_baselines3.common.type_aliases import Schedule
+from stable_baselines3.common.type_aliases import Schedule, RolloutBufferSamples
 from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
+
+from metastable_projections.projections import BaseProjectionLayer, IdentityProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
+
 
 from .distributions import make_proba_distribution
 from metastable_baselines2.common.pca import PCA_Distribution
 
@@ -395,7 +398,7 @@ class BasePolicy(BaseModel, ABC):
 class ActorCriticPolicy(BasePolicy):
     """
     Policy class for actor-critic algorithms (has both policy and value prediction).
-    Used by A2C, PPO and the likes.
+    Used by A2C, PPO, TRPL and the likes.
 
     :param observation_space: Observation space
     :param action_space: Action space
@@ -445,6 +448,7 @@ class ActorCriticPolicy(BasePolicy):
         optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         dist_kwargs: Optional[Dict[str, Any]] = {},
+        policy_projection: BaseProjectionLayer = IdentityProjectionLayer(),
     ):
         if optimizer_kwargs is None:
             optimizer_kwargs = {}
@@ -514,6 +518,8 @@ class ActorCriticPolicy(BasePolicy):
         self.use_pca = use_pca
         self.dist_kwargs = dist_kwargs
 
+        self.policy_projection = policy_projection
+
         # Action distribution
         self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, use_pca=use_pca, dist_kwargs=dist_kwargs)
 
@@ -624,7 +630,7 @@ class ActorCriticPolicy(BasePolicy):
         # Setup optimizer with initial learning rate
         self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
 
-    def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
+    def forward(self, obs: th.Tensor, deterministic: bool = False, conditioned_log_probs: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor, Distribution]:
         """
         Forward pass in all the networks (actor and critic)
 
@@ -644,9 +650,13 @@ class ActorCriticPolicy(BasePolicy):
         values = self.value_net(latent_vf)
         distribution = self._get_action_dist_from_latent(latent_pi)
         actions = distribution.get_actions(deterministic=deterministic)
-        log_prob = distribution.log_prob(actions)
+        if conditioned_log_probs:
+            assert self.use_pca, 'Cannot calculate conditioned log probs when PCA is disabled.'
+            log_prob = distribution.conditioned_log_prob(actions)
+        else:
+            log_prob = distribution.log_prob(actions)
         actions = actions.reshape((-1, *self.action_space.shape))
-        return actions, values, log_prob
+        return actions, values, log_prob, distribution
 
     def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
         """
@@ -691,7 +701,7 @@ class ActorCriticPolicy(BasePolicy):
         else:
             raise ValueError("Invalid action distribution")
 
-    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def _predict(self, observation: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor:
         """
         Get the action according to the policy for a given observation.
 
@@ -699,19 +709,23 @@ class ActorCriticPolicy(BasePolicy):
         :param deterministic: Whether to use stochastic or deterministic actions
         :return: Taken action according to the policy
         """
-        return self.get_distribution(observation).get_actions(deterministic=deterministic)
+        if self.use_pca:
+            return self.get_distribution(observation).get_actions(deterministic=deterministic, trajectory=trajectory)
+        else:
+            return self.get_distribution(observation).get_actions(deterministic=deterministic)
 
-    def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
+    def evaluate_actions(self, rollout_data: RolloutBufferSamples, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor], th.Tensor]:
         """
         Evaluate actions according to the current policy,
         given the observations.
 
-        :param obs: Observation
+        :param rollout_data: The rollout samples (containing the observations and the old distribution parameters)
         :param actions: Actions
         :return: estimated value, log likelihood of taking those actions
         and entropy of the action distribution.
         """
         # Preprocess the observation if needed
+        obs = rollout_data.observations
         features = self.extract_features(obs)
         if self.share_features_extractor:
             latent_pi, latent_vf = self.mlp_extractor(features)
@@ -719,11 +733,13 @@ class ActorCriticPolicy(BasePolicy):
             pi_features, vf_features = features
             latent_pi = self.mlp_extractor.forward_actor(pi_features)
             latent_vf = self.mlp_extractor.forward_critic(vf_features)
-        distribution = self._get_action_dist_from_latent(latent_pi)
+        raw_distribution = self._get_action_dist_from_latent(latent_pi)
+        distribution, old_distribution = self.policy_projection.project_from_rollouts(raw_distribution, rollout_data)
         log_prob = distribution.log_prob(actions)
         values = self.value_net(latent_vf)
         entropy = distribution.entropy()
-        return values, log_prob, entropy
+        trust_region_loss = self.policy_projection.get_trust_region_loss(raw_distribution, old_distribution)
+        return values, log_prob, entropy, trust_region_loss
 
     def get_distribution(self, obs: th.Tensor) -> Distribution:
         """
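
Since evaluate_actions now returns a trust-region term alongside values, log-probs and entropy, the algorithm's train() step is expected to fold it into its objective. A schematic sketch of how such a loss could be assembled (hypothetical helper with illustrative coefficients; not the repository's actual training code):

import torch as th
import torch.nn.functional as F

def surrogate_loss_with_trust_region(rollout_data, values, log_prob, entropy, trust_region_loss,
                                     vf_coef=0.5, ent_coef=0.0):
    # Unclipped policy-gradient surrogate plus value and entropy terms, with the
    # projection's trust-region penalty added on top.
    ratio = th.exp(log_prob - rollout_data.old_log_prob)
    policy_loss = -(rollout_data.advantages * ratio).mean()
    value_loss = F.mse_loss(rollout_data.returns, values.flatten())
    return policy_loss + vf_coef * value_loss - ent_coef * entropy.mean() + trust_region_loss
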
@@ -918,18 +934,18 @@ class Actor(BasePolicy):
         log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
         return mean_actions, log_std, {}
 
-    def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def forward(self, obs: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor:
         mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
         # Note: the action is squashed
-        return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
+        return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, trajectory=trajectory, **kwargs)
 
     def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
         mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
         # return action and associated log prob
         return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
 
-    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
-        return self(observation, deterministic)
+    def _predict(self, observation: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor:
+        return self(observation, deterministic, trajectory=trajectory)
 
 
 class SACPolicy(BasePolicy):
@@ -1018,16 +1034,14 @@ class SACPolicy(BasePolicy):
         }
         self.actor_kwargs = self.net_args.copy()
 
-        sde_kwargs = {
+        self.actor_kwargs.update({
             "use_sde": use_sde,
             "use_pca": use_pca,
             "log_std_init": log_std_init,
             "use_expln": use_expln,
             "clip_mean": clip_mean,
             "dist_kwargs": dist_kwargs,
-        }
+        })
 
-        self.actor_kwargs.update(sde_kwargs)
         self.critic_kwargs = self.net_args.copy()
         self.critic_kwargs.update(
             {