Implement Buffer that stores Distribution params
This commit is contained in:
parent
5ed5d32083
commit
1fa66611a3
386
metastable_baselines2/common/buffers.py
Normal file
@@ -0,0 +1,386 @@
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Tuple, Union, NamedTuple

import numpy as np
import torch as th
from gymnasium import spaces

from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.type_aliases import (
    DictRolloutBufferSamples,
    RolloutBufferSamples,
    TensorDict
)
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize

try:
    # Check memory used by replay buffer when possible
    import psutil
except ImportError:
    psutil = None

from stable_baselines3.common.buffers import RolloutBuffer, DictRolloutBuffer


class BetterRolloutBufferSamples(NamedTuple):
    observations: th.Tensor
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    mean: th.Tensor
    cov_decomp: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor


class BetterDictRolloutBufferSamples(NamedTuple):
    observations: TensorDict
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    mean: th.Tensor
    cov_decomp: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor


class BetterRolloutBuffer(RolloutBuffer):
    """
    Extended to also save the mean and covariance decomposition of the action distribution.

    Rollout buffer used in on-policy algorithms like A2C/PPO.
    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    observations: np.ndarray
    actions: np.ndarray
    rewards: np.ndarray
    advantages: np.ndarray
    returns: np.ndarray
    episode_starts: np.ndarray
    log_probs: np.ndarray
    values: np.ndarray
    means: np.ndarray
    cov_decomps: np.ndarray

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
        self.gae_lambda = gae_lambda
        self.gamma = gamma
        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # Parameters of the action distribution at each step: the mean and the
        # covariance decomposition (e.g. the scale), one vector of size action_dim per env.
        self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.generator_ready = False
        super().reset()

    def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
        """
        Post-processing step: compute the lambda-return (TD(lambda) estimate)
        and GAE(lambda) advantage.

        Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
        to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S))
        where R is the sum of discounted reward with value bootstrap
        (because we don't always have full episode), set ``gae_lambda=1.0`` during initialization.

        The TD(lambda) estimator has also two special cases:
        - TD(1) is Monte-Carlo estimate (sum of discounted rewards)
        - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))

        For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.

        :param last_values: state value estimation for the last step (one for each env)
        :param dones: if the last step was a terminal step (one bool for each env).
        """
        # Convert to numpy
        last_values = last_values.clone().cpu().numpy().flatten()

        last_gae_lam = 0
        for step in reversed(range(self.buffer_size)):
            if step == self.buffer_size - 1:
                next_non_terminal = 1.0 - dones
                next_values = last_values
            else:
                next_non_terminal = 1.0 - self.episode_starts[step + 1]
                next_values = self.values[step + 1]
            delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam
        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
        self.returns = self.advantages + self.values
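        # Worked reading of the recurrence above (the gamma / gae_lambda values are illustrative only):
        #   delta_t = r_t + 0.99 * V(s_{t+1}) * (1 - episode_start_{t+1}) - V(s_t)
        #   A_t     = delta_t + 0.99 * 0.95 * (1 - episode_start_{t+1}) * A_{t+1}
        # and the value targets used for training are simply R_t = A_t + V(s_t).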

    def add(
        self,
        obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
        mean: th.Tensor,
        cov_decomp: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward: Reward
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        :param mean: mean of the action distribution the action was sampled from.
        :param cov_decomp: decomposition of the covariance (e.g. the scale) of that distribution.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs, *self.obs_shape))

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        self.observations[self.pos] = np.array(obs)
        self.actions[self.pos] = np.array(action)
        self.rewards[self.pos] = np.array(reward)
        self.episode_starts[self.pos] = np.array(episode_start)
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.means[self.pos] = mean.clone().cpu().numpy()
        self.cov_decomps[self.pos] = cov_decomp.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size: Optional[int] = None) -> Generator[BetterRolloutBufferSamples, None, None]:
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            _tensor_names = [
                "observations",
                "actions",
                "values",
                "log_probs",
                "means",
                "cov_decomps",
                "advantages",
                "returns",
            ]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> BetterRolloutBufferSamples:
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            # Keep the full (action_dim,) vectors for the stored distribution parameters
            self.means[batch_inds],
            self.cov_decomps[batch_inds],
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
        )
        return BetterRolloutBufferSamples(*tuple(map(self.to_torch, data)))


class BetterDictRolloutBuffer(DictRolloutBuffer):
    """
    Extended to also save the mean and covariance decomposition of the action distribution.

    Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
    Extends the RolloutBuffer to use dictionary observations

    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to Monte-Carlo advantage estimate when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    observation_space: spaces.Dict
    obs_shape: Dict[str, Tuple[int, ...]]  # type: ignore[assignment]
    observations: Dict[str, np.ndarray]  # type: ignore[assignment]

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Dict,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"

        self.gae_lambda = gae_lambda
        self.gamma = gamma

        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        self.observations = {}
        for key, obs_input_shape in self.obs_shape.items():
            self.observations[key] = np.zeros((self.buffer_size, self.n_envs, *obs_input_shape), dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # Distribution parameters (mean and covariance decomposition), as in BetterRolloutBuffer
        self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.generator_ready = False
        super(RolloutBuffer, self).reset()

    def add(  # type: ignore[override]
        self,
        obs: Dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
        mean: th.Tensor,
        cov_decomp: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward: Reward
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        :param mean: mean of the action distribution the action was sampled from.
        :param cov_decomp: decomposition of the covariance (e.g. the scale) of that distribution.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        for key in self.observations.keys():
            obs_ = np.array(obs[key])
            # Reshape needed when using multiple envs with discrete observations
            # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
                obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
            self.observations[key][self.pos] = obs_

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        self.actions[self.pos] = np.array(action)
        self.rewards[self.pos] = np.array(reward)
        self.episode_starts[self.pos] = np.array(episode_start)
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.means[self.pos] = mean.clone().cpu().numpy()
        self.cov_decomps[self.pos] = cov_decomp.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(  # type: ignore[override]
        self,
        batch_size: Optional[int] = None,
    ) -> Generator[BetterDictRolloutBufferSamples, None, None]:
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            for key, obs in self.observations.items():
                self.observations[key] = self.swap_and_flatten(obs)

            _tensor_names = ["actions", "values", "log_probs", "means", "cov_decomps", "advantages", "returns"]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(  # type: ignore[override]
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> BetterDictRolloutBufferSamples:
        return BetterDictRolloutBufferSamples(
            observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
            actions=self.to_torch(self.actions[batch_inds]),
            old_values=self.to_torch(self.values[batch_inds].flatten()),
            old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
            # Keep the full (action_dim,) vectors for the stored distribution parameters
            mean=self.to_torch(self.means[batch_inds]),
            cov_decomp=self.to_torch(self.cov_decomps[batch_inds]),
            advantages=self.to_torch(self.advantages[batch_inds].flatten()),
            returns=self.to_torch(self.returns[batch_inds].flatten()),
        )
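
Below is a minimal, hedged usage sketch (not part of the commit) of how the extended buffer is meant to round-trip the distribution parameters; the Box spaces, the sizes, and the diagonal Gaussian built with `th.distributions.Normal` are assumptions chosen only for illustration.

import numpy as np
import torch as th
from gymnasium import spaces
from metastable_baselines2.common.buffers import BetterRolloutBuffer

# Hypothetical spaces and sizes for the example
obs_space = spaces.Box(low=-1.0, high=1.0, shape=(4,))
act_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))
buffer = BetterRolloutBuffer(buffer_size=8, observation_space=obs_space, action_space=act_space, n_envs=2)

for _ in range(8):
    # Pretend the policy produced a diagonal Gaussian per environment
    dist = th.distributions.Normal(th.zeros(2, 2), th.ones(2, 2))
    actions = dist.sample()
    buffer.add(
        obs=np.zeros((2, 4), dtype=np.float32),
        action=actions.numpy(),
        reward=np.zeros(2, dtype=np.float32),
        episode_start=np.zeros(2, dtype=np.float32),
        value=th.zeros(2),
        log_prob=dist.log_prob(actions).sum(-1),
        mean=dist.mean,          # stored so the old distribution can be rebuilt later
        cov_decomp=dist.stddev,  # here: the scale of a diagonal Gaussian
    )

buffer.compute_returns_and_advantage(last_values=th.zeros(2), dones=np.zeros(2))
for samples in buffer.get(batch_size=4):
    # The stored parameters come back as (batch_size, action_dim) tensors
    assert samples.mean.shape == (4, 2)
    assert samples.cov_decomp.shape == (4, 2)
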
@@ -8,6 +8,7 @@ from gymnasium import spaces

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from .buffers import BetterDictRolloutBuffer, BetterRolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@@ -70,6 +71,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
        use_sde: bool,
        sde_sample_freq: int,
        use_pca: bool,
        pca_is: bool,
        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
        stats_window_size: int = 100,
@@ -85,6 +87,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
        assert not (use_sde and use_pca)
        self.use_pca = use_pca

        assert not pca_is or use_pca
        self.pca_is = pca_is

        assert not rollout_buffer_class and not rollout_buffer_kwargs

        super().__init__(
@@ -122,9 +127,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):

        if self.rollout_buffer_class is None:
            if isinstance(self.observation_space, spaces.Dict):
                self.rollout_buffer_class = DictRolloutBuffer
                self.rollout_buffer_class = BetterDictRolloutBuffer
            else:
                self.rollout_buffer_class = RolloutBuffer
                self.rollout_buffer_class = BetterRolloutBuffer

        self.rollout_buffer = self.rollout_buffer_class(
            self.n_steps,
@@ -186,7 +191,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
            with th.no_grad():
                # Convert to pytorch tensor or to TensorDict
                obs_tensor = obs_as_tensor(self._last_obs, self.device)
                actions, values, log_probs = self.policy(obs_tensor)
                actions, values, log_probs, distributions = self.policy(obs_tensor, conditioned_log_probs=self.pca_is)
            actions = actions.cpu().numpy()

            # Rescale and perform action
@@ -231,7 +236,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
                        terminal_value = self.policy.predict_values(terminal_obs)[0]
                    rewards[idx] += self.gamma * terminal_value

            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)
            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.mean, distributions.scale)
            self._last_obs = new_obs
            self._last_episode_starts = dones

@@ -254,6 +259,30 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
        """
        raise NotImplementedError

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
        trajectory: Optional[th.Tensor] = None,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Get the policy action from an observation (and optional hidden state).
        Includes sugar-coating to handle different observations (e.g. normalizing images).

        :param observation: the input observation
        :param state: The last hidden states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies)
            this corresponds to the beginning of episodes,
            where the hidden states of the RNN must be reset.
        :param deterministic: Whether or not to return deterministic actions.
        :param trajectory: Past trajectory. Only required when using PCA.
        :return: the model's action and the next hidden state
            (used in recurrent policies)
        """
        return self.policy.predict(observation, state, episode_start, deterministic, trajectory=trajectory)

    def learn(
        self: SelfOnPolicyAlgorithm,
        total_timesteps: int,

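As a hedged illustration of the new `predict` signature above, a small helper sketch (not from the repository); the `model` argument is assumed to be a trained algorithm instance built on this class, and the trajectory layout is an assumption:

from typing import Dict, Optional, Union

import numpy as np
import torch as th


def act(model, obs: Union[np.ndarray, Dict[str, np.ndarray]], past_trajectory: Optional[th.Tensor] = None) -> np.ndarray:
    # Without PCA no trajectory is needed and the call matches stock stable-baselines3
    if past_trajectory is None:
        action, _ = model.predict(obs, deterministic=True)
    else:
        # With use_pca=True the policy can condition the action on the past trajectory
        action, _ = model.predict(obs, deterministic=True, trajectory=past_trajectory)
    return action
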
@@ -34,9 +34,12 @@ from stable_baselines3.common.torch_layers import (

from stable_baselines3.common.policies import ContinuousCritic

from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.type_aliases import Schedule, RolloutBufferSamples
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor

from metastable_projections.projections import BaseProjectionLayer, IdentityProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer


from .distributions import make_proba_distribution
from metastable_baselines2.common.pca import PCA_Distribution

@@ -395,7 +398,7 @@ class BasePolicy(BaseModel, ABC):
class ActorCriticPolicy(BasePolicy):
    """
    Policy class for actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO and the likes.
    Used by A2C, PPO, TRPL and the likes.

    :param observation_space: Observation space
    :param action_space: Action space
@@ -445,6 +448,7 @@ class ActorCriticPolicy(BasePolicy):
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        dist_kwargs: Optional[Dict[str, Any]] = {},
        policy_projection: BaseProjectionLayer = IdentityProjectionLayer(),
    ):
        if optimizer_kwargs is None:
            optimizer_kwargs = {}
@@ -514,6 +518,8 @@ class ActorCriticPolicy(BasePolicy):
        self.use_pca = use_pca
        self.dist_kwargs = dist_kwargs

        self.policy_projection = policy_projection

        # Action distribution
        self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, use_pca=use_pca, dist_kwargs=dist_kwargs)

@@ -624,7 +630,7 @@ class ActorCriticPolicy(BasePolicy):
        # Setup optimizer with initial learning rate
        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

    def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    def forward(self, obs: th.Tensor, deterministic: bool = False, conditioned_log_probs: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor, Distribution]:
        """
        Forward pass in all the networks (actor and critic)

@@ -644,9 +650,13 @@ class ActorCriticPolicy(BasePolicy):
        values = self.value_net(latent_vf)
        distribution = self._get_action_dist_from_latent(latent_pi)
        actions = distribution.get_actions(deterministic=deterministic)
        log_prob = distribution.log_prob(actions)
        if conditioned_log_probs:
            assert self.use_pca, 'Cannot calculate conditioned log probs when PCA is disabled.'
            log_prob = distribution.conditioned_log_prob(actions)
        else:
            log_prob = distribution.log_prob(actions)
        actions = actions.reshape((-1, *self.action_space.shape))
        return actions, values, log_prob
        return actions, values, log_prob, distribution

    def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
        """
@@ -691,7 +701,7 @@ class ActorCriticPolicy(BasePolicy):
        else:
            raise ValueError("Invalid action distribution")

    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
    def _predict(self, observation: th.Tensor, deterministic: bool = False, trajectory: Optional[th.Tensor] = None) -> th.Tensor:
        """
        Get the action according to the policy for a given observation.

@@ -699,19 +709,23 @@ class ActorCriticPolicy(BasePolicy):
        :param deterministic: Whether to use stochastic or deterministic actions
        :return: Taken action according to the policy
        """
        return self.get_distribution(observation).get_actions(deterministic=deterministic)
        if self.use_pca:
            return self.get_distribution(observation).get_actions(deterministic=deterministic, trajectory=trajectory)
        else:
            return self.get_distribution(observation).get_actions(deterministic=deterministic)

    def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
    def evaluate_actions(self, rollout_data: RolloutBufferSamples, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor], th.Tensor]:
        """
        Evaluate actions according to the current policy,
        given the observations.

        :param obs: Observation
        :param rollout_data: The rollout samples (containing the observations and the stored parameters of the old action distribution)
        :param actions: Actions
        :return: estimated value, log likelihood of taking those actions
            and entropy of the action distribution.
        """
        # Preprocess the observation if needed
        obs = rollout_data.observations
        features = self.extract_features(obs)
        if self.share_features_extractor:
            latent_pi, latent_vf = self.mlp_extractor(features)
@@ -719,11 +733,13 @@ class ActorCriticPolicy(BasePolicy):
            pi_features, vf_features = features
            latent_pi = self.mlp_extractor.forward_actor(pi_features)
            latent_vf = self.mlp_extractor.forward_critic(vf_features)
        distribution = self._get_action_dist_from_latent(latent_pi)
        raw_distribution = self._get_action_dist_from_latent(latent_pi)
        distribution, old_distribution = self.policy_projection.project_from_rollouts(raw_distribution, rollout_data)
        log_prob = distribution.log_prob(actions)
        values = self.value_net(latent_vf)
        entropy = distribution.entropy()
        return values, log_prob, entropy
        trust_region_loss = self.policy_projection.get_trust_region_loss(raw_distribution, old_distribution)
        return values, log_prob, entropy, trust_region_loss

    def get_distribution(self, obs: th.Tensor) -> Distribution:
        """
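
The stored mean and covariance decomposition are what this projection step relies on: `evaluate_actions` can rebuild the behaviour-time (old) distribution from the rollout data instead of re-evaluating the old network. A conceptual sketch of that idea in plain PyTorch (not the metastable_projections API; the shapes and the diagonal-Gaussian form are assumptions):

import torch as th

batch_size, action_dim = 64, 2                        # assumed sizes
old_mean = th.zeros(batch_size, action_dim)           # e.g. rollout_data.mean
old_scale = th.ones(batch_size, action_dim)           # e.g. rollout_data.cov_decomp (diagonal scale)
new_mean = th.zeros(batch_size, action_dim, requires_grad=True)
new_scale = th.ones(batch_size, action_dim)

old_dist = th.distributions.Normal(old_mean, old_scale)
new_dist = th.distributions.Normal(new_mean, new_scale)

# A trust-region style term can then be computed directly, e.g. the mean KL
# divergence between the current and the stored (old) policy distribution.
kl = th.distributions.kl_divergence(new_dist, old_dist).sum(-1).mean()
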
@@ -918,18 +934,18 @@ class Actor(BasePolicy):
        log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        return mean_actions, log_std, {}

    def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
    def forward(self, obs: th.Tensor, deterministic: bool = False, trajectory: Optional[th.Tensor] = None) -> th.Tensor:
        mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
        # Note: the action is squashed
        return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
        return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, trajectory=trajectory, **kwargs)

    def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
        # return action and associated log prob
        return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)

    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
        return self(observation, deterministic)
    def _predict(self, observation: th.Tensor, deterministic: bool = False, trajectory: Optional[th.Tensor] = None) -> th.Tensor:
        return self(observation, deterministic, trajectory=trajectory)


class SACPolicy(BasePolicy):
@@ -1018,16 +1034,14 @@ class SACPolicy(BasePolicy):
        }
        self.actor_kwargs = self.net_args.copy()

        sde_kwargs = {
        self.actor_kwargs.update({
            "use_sde": use_sde,
            "use_pca": use_pca,
            "log_std_init": log_std_init,
            "use_expln": use_expln,
            "clip_mean": clip_mean,
            "dist_kwargs": dist_kwargs,
        }

        self.actor_kwargs.update(sde_kwargs)
        })
        self.critic_kwargs = self.net_args.copy()
        self.critic_kwargs.update(
            {