"""Policies: abstract base class and concrete implementations."""
|
|
|
|
import collections
|
|
import copy
|
|
import warnings
|
|
from abc import ABC, abstractmethod
|
|
from functools import partial
|
|
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
|
|
|
import numpy as np
|
|
import torch as th
|
|
from gymnasium import spaces
|
|
from torch import nn
|
|
import math
|
|
|
|
from stable_baselines3.common.distributions import (
|
|
BernoulliDistribution,
|
|
CategoricalDistribution,
|
|
DiagGaussianDistribution,
|
|
Distribution,
|
|
MultiCategoricalDistribution,
|
|
StateDependentNoiseDistribution,
|
|
SquashedDiagGaussianDistribution,
|
|
)
|
|
from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, maybe_transpose, preprocess_obs
|
|
from stable_baselines3.common.torch_layers import (
|
|
BaseFeaturesExtractor,
|
|
CombinedExtractor,
|
|
FlattenExtractor,
|
|
MlpExtractor,
|
|
NatureCNN,
|
|
create_mlp,
|
|
get_actor_critic_arch,
|
|
)
|
|
|
|
from stable_baselines3.common.policies import ContinuousCritic
|
|
|
|
from stable_baselines3.common.type_aliases import Schedule, RolloutBufferSamples
|
|
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
|
|
|
|
from metastable_projections.projections import BaseProjectionLayer, IdentityProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
|
|
|
|
|
|
from .distributions import make_proba_distribution
|
|
from metastable_baselines2.common.pca import PCA_Distribution
|
|
|
|
|
|
SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")
|
|
|
|
|
|
class BaseModel(nn.Module):
    """
    The base model object: makes predictions in response to observations.

    In the case of policies, the prediction is an action. In the case of critics, it is the
    estimated value of the observation.

    :param observation_space: The observation space of the environment
    :param action_space: The action space of the environment
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param features_extractor: Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        features_extractor: Optional[BaseFeaturesExtractor] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()

        if optimizer_kwargs is None:
            optimizer_kwargs = {}

        if features_extractor_kwargs is None:
            features_extractor_kwargs = {}

        self.observation_space = observation_space
        self.action_space = action_space
        self.features_extractor = features_extractor
        self.normalize_images = normalize_images

        self.optimizer_class = optimizer_class
        self.optimizer_kwargs = optimizer_kwargs
        self.optimizer: th.optim.Optimizer

        self.features_extractor_class = features_extractor_class
        self.features_extractor_kwargs = features_extractor_kwargs
        # Automatically deactivate dtype and bounds checks
        if normalize_images is False and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)):
            self.features_extractor_kwargs.update(dict(normalized_image=True))

    def _update_features_extractor(
        self,
        net_kwargs: Dict[str, Any],
        features_extractor: Optional[BaseFeaturesExtractor] = None,
    ) -> Dict[str, Any]:
        """
        Update the network keyword arguments and create a new features extractor object if needed.
        If a ``features_extractor`` object is passed, then it will be shared.

        :param net_kwargs: the base network keyword arguments, without the ones
            related to features extractor
        :param features_extractor: a features extractor object.
            If None, a new object will be created.
        :return: The updated keyword arguments
        """
        net_kwargs = net_kwargs.copy()
        if features_extractor is None:
            # The features extractor is not shared, create a new one
            features_extractor = self.make_features_extractor()
        net_kwargs.update(dict(features_extractor=features_extractor, features_dim=features_extractor.features_dim))
        return net_kwargs

    def make_features_extractor(self) -> BaseFeaturesExtractor:
        """Helper method to create a features extractor."""
        return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)

    def extract_features(self, obs: th.Tensor, features_extractor: BaseFeaturesExtractor) -> th.Tensor:
        """
        Preprocess the observation if needed and extract features.

        :param obs: The observation
        :param features_extractor: The features extractor to use.
        :return: The extracted features
        """
        preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
        return features_extractor(preprocessed_obs)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """
        Get data that need to be saved in order to re-create the model when loading it from disk.

        :return: The dictionary to pass as kwargs to the constructor when reconstructing this model.
        """
        return dict(
            observation_space=self.observation_space,
            action_space=self.action_space,
            # Passed to the constructor by child class
            # squash_output=self.squash_output,
            # features_extractor=self.features_extractor
            normalize_images=self.normalize_images,
        )

    @property
    def device(self) -> th.device:
        """Infer which device this policy lives on by inspecting its parameters.
        If it has no parameters, the 'cpu' device is used as a fallback.

        :return:"""
        for param in self.parameters():
            return param.device
        return get_device("cpu")

    def save(self, path: str) -> None:
        """
        Save model to a given location.

        :param path:
        """
        th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path)

    @classmethod
    def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "auto") -> SelfBaseModel:
        """
        Load model from path.

        :param path:
        :param device: Device on which the policy should be loaded.
        :return:
        """
        device = get_device(device)
        saved_variables = th.load(path, map_location=device)

        # Create policy object
        model = cls(**saved_variables["data"])  # pytype: disable=not-instantiable
        # Load weights
        model.load_state_dict(saved_variables["state_dict"])
        model.to(device)
        return model
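    # Illustrative sketch (not part of the API): ``save``/``load`` round-trip a model
    # through a single file holding both the weights and the constructor kwargs.
    # Assuming ``policy`` is any trained subclass instance and the path is arbitrary:
    #
    #     policy.save("/tmp/policy.pt")
    #     restored = type(policy).load("/tmp/policy.pt", device="cpu")
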
    def load_from_vector(self, vector: np.ndarray) -> None:
        """
        Load parameters from a 1D vector.

        :param vector:
        """
        th.nn.utils.vector_to_parameters(th.as_tensor(vector, dtype=th.float, device=self.device), self.parameters())

    def parameters_to_vector(self) -> np.ndarray:
        """
        Convert the parameters to a 1D vector.

        :return:
        """
        return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()

    def set_training_mode(self, mode: bool) -> None:
        """
        Put the policy in either training or evaluation mode.

        This affects certain modules, such as batch normalisation and dropout.

        :param mode: if true, set to training mode, else set to evaluation mode
        """
        self.train(mode)

    def is_vectorized_observation(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> bool:
        """
        Check whether or not the observation is vectorized,
        apply transposition to images (so that they are channel-first) if needed.
        This is used in DQN when sampling random actions (epsilon-greedy policy).

        :param observation: the input observation to check
        :return: whether the given observation is vectorized or not
        """
        vectorized_env = False
        if isinstance(observation, dict):
            for key, obs in observation.items():
                obs_space = self.observation_space.spaces[key]
                vectorized_env = vectorized_env or is_vectorized_observation(maybe_transpose(obs, obs_space), obs_space)
        else:
            vectorized_env = is_vectorized_observation(
                maybe_transpose(observation, self.observation_space), self.observation_space
            )
        return vectorized_env

    def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[th.Tensor, bool]:
        """
        Convert an input observation to a PyTorch tensor that can be fed to a model.
        Includes sugar-coating to handle different observations (e.g. normalizing images).

        :param observation: the input observation
        :return: The observation as PyTorch tensor
            and whether the observation is vectorized or not
        """
        vectorized_env = False
        if isinstance(observation, dict):
            # need to copy the dict as the dict in VecFrameStack will become a torch tensor
            observation = copy.deepcopy(observation)
            for key, obs in observation.items():
                obs_space = self.observation_space.spaces[key]
                if is_image_space(obs_space):
                    obs_ = maybe_transpose(obs, obs_space)
                else:
                    obs_ = np.array(obs)
                vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space)
                # Add batch dimension if needed
                observation[key] = obs_.reshape((-1, *self.observation_space[key].shape))

        elif is_image_space(self.observation_space):
            # Handle the different cases for images
            # as PyTorch use channel first format
            observation = maybe_transpose(observation, self.observation_space)

        else:
            observation = np.array(observation)

        if not isinstance(observation, dict):
            # Dict obs need to be handled separately
            vectorized_env = is_vectorized_observation(observation, self.observation_space)
            # Add batch dimension if needed
            observation = observation.reshape((-1, *self.observation_space.shape))

        observation = obs_as_tensor(observation, self.device)
        return observation, vectorized_env
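    # Illustrative sketch (assumes a ``Box`` observation space of shape (3,)):
    # ``obs_to_tensor`` adds a batch dimension for single observations and reports
    # whether the input was already batched.
    #
    #     obs = np.zeros(3, dtype=np.float32)
    #     tensor_obs, was_vectorized = model.obs_to_tensor(obs)
    #     # tensor_obs.shape == (1, 3), was_vectorized is False

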
class BasePolicy(BaseModel, ABC):
    """The base policy object.

    Parameters are mostly the same as ``BaseModel``; additions are documented below.

    :param args: positional arguments passed through to ``BaseModel``.
    :param kwargs: keyword arguments passed through to ``BaseModel``.
    :param squash_output: For continuous actions, whether the output is squashed
        or not using a ``tanh()`` function.
    """

    def __init__(self, *args, squash_output: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self._squash_output = squash_output

    @staticmethod
    def _dummy_schedule(progress_remaining: float) -> float:
        """(float) Useful for pickling policy."""
        del progress_remaining
        return 0.0

    @property
    def squash_output(self) -> bool:
        """(bool) Getter for squash_output."""
        return self._squash_output

    @staticmethod
    def init_weights(module: nn.Module, gain: float = 1) -> None:
        """
        Orthogonal initialization (used in PPO and A2C)
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.orthogonal_(module.weight, gain=gain)
            if module.bias is not None:
                module.bias.data.fill_(0.0)

    @abstractmethod
    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
        """
        Get the action according to the policy for a given observation.

        :param observation:
        :param deterministic: Whether to use stochastic or deterministic actions
        :return: Taken action according to the policy
        """

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
        trajectory=None,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Get the policy action from an observation (and optional hidden state).
        Includes sugar-coating to handle different observations (e.g. normalizing images).

        :param observation: the input observation
        :param state: The last hidden states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies)
            this corresponds to the beginning of episodes,
            where the hidden states of the RNN must be reset.
        :param deterministic: Whether or not to return deterministic actions.
        :param trajectory: Optional trajectory information for PCA-based policies
            (not used by this base implementation).
        :return: the model's action and the next hidden state
            (used in recurrent policies)
        """
        # Switch to eval mode (this affects batch norm / dropout)
        self.set_training_mode(False)

        observation, vectorized_env = self.obs_to_tensor(observation)

        with th.no_grad():
            actions = self._predict(observation, deterministic=deterministic)
        # Convert to numpy, and reshape to the original action shape
        actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape))

        if isinstance(self.action_space, spaces.Box):
            if self.squash_output:
                # Rescale to proper domain when using squashing
                actions = self.unscale_action(actions)
            else:
                # Actions could be on arbitrary scale, so clip the actions to avoid
                # out of bound error (e.g. if sampling from a Gaussian distribution)
                actions = np.clip(actions, self.action_space.low, self.action_space.high)

        # Remove batch dimension if needed
        if not vectorized_env:
            actions = actions.squeeze(axis=0)

        return actions, state
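    # Illustrative sketch (not part of the API): ``predict`` is the NumPy-facing entry
    # point used during evaluation. Assuming ``policy`` is a concrete subclass and
    # ``env`` a Gymnasium environment:
    #
    #     obs, _ = env.reset()
    #     action, _state = policy.predict(obs, deterministic=True)
    #     obs, reward, terminated, truncated, info = env.step(action)
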
    def scale_action(self, action: np.ndarray) -> np.ndarray:
        """
        Rescale the action from [low, high] to [-1, 1]
        (no need for symmetric action space)

        :param action: Action to scale
        :return: Scaled action
        """
        low, high = self.action_space.low, self.action_space.high
        return 2.0 * ((action - low) / (high - low)) - 1.0

    def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
        """
        Rescale the action from [-1, 1] to [low, high]
        (no need for symmetric action space)

        :param scaled_action: Action to un-scale
        """
        low, high = self.action_space.low, self.action_space.high
        return low + (0.5 * (scaled_action + 1.0) * (high - low))
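    # Worked example (illustrative): for a Box action space with low=-2 and high=2,
    # scale_action(1.0) == 2.0 * ((1.0 - (-2)) / 4) - 1.0 == 0.5, and
    # unscale_action(0.5) == -2 + 0.5 * (0.5 + 1.0) * 4 == 1.0, i.e. the two
    # transforms are inverses of each other.

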
class ActorCriticPolicy(BasePolicy):
    """
    Policy class for actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO, TRPL and the like.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether or not to use orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param use_pca: Whether to use the ``PCA_Distribution`` for exploration
        instead of a plain diagonal Gaussian
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param dist_kwargs: Additional keyword arguments to pass to the action distribution
    :param policy_projection: Trust-region projection layer applied in ``evaluate_actions``
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        use_pca: bool = False,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        dist_kwargs: Optional[Dict[str, Any]] = None,
        policy_projection: BaseProjectionLayer = IdentityProjectionLayer(),
    ):
        if optimizer_kwargs is None:
            optimizer_kwargs = {}
            # Small values to avoid NaN in Adam optimizer
            if optimizer_class == th.optim.Adam:
                optimizer_kwargs["eps"] = 1e-5

        # Avoid the mutable default argument pitfall and do not mutate the caller's dict
        dist_kwargs = {} if dist_kwargs is None else dict(dist_kwargs)

        # Allow activation functions to be given by name (e.g. from a config file)
        if activation_fn == 'ReLU':
            activation_fn = nn.ReLU
        elif activation_fn == 'tanh':
            activation_fn = nn.Tanh

        super().__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            squash_output=squash_output,
            normalize_images=normalize_images,
        )

        if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict):
            warnings.warn(
                (
                    "As shared layers in the mlp_extractor are removed since SB3 v1.8.0, "
                    "you should now pass directly a dictionary and not a list "
                    "(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
                ),
            )
            net_arch = net_arch[0]

        # Default network architecture, from stable-baselines
        if net_arch is None:
            if features_extractor_class == NatureCNN:
                net_arch = []
            else:
                net_arch = dict(pi=[64, 64], vf=[64, 64])

        self.net_arch = net_arch
        self.activation_fn = activation_fn
        self.ortho_init = ortho_init

        self.share_features_extractor = share_features_extractor
        self.features_extractor = self.make_features_extractor()
        self.features_dim = self.features_extractor.features_dim
        if self.share_features_extractor:
            self.pi_features_extractor = self.features_extractor
            self.vf_features_extractor = self.features_extractor
        else:
            self.pi_features_extractor = self.features_extractor
            self.vf_features_extractor = self.make_features_extractor()

        self.log_std_init = log_std_init
        # Keyword arguments for gSDE distribution
        if use_sde:
            add_dist_kwargs = {
                "full_std": full_std,
                "squash_output": squash_output,
                "use_expln": use_expln,
                "learn_features": False,
            }
            dist_kwargs.update(add_dist_kwargs)
        if use_pca:
            add_dist_kwargs = {
                "init_std": math.exp(self.log_std_init),
            }
            dist_kwargs.update(add_dist_kwargs)

        self.use_sde = use_sde
        self.use_pca = use_pca
        self.dist_kwargs = dist_kwargs

        self.policy_projection = policy_projection

        self.n_envs = dist_kwargs.pop('n_envs', 1)

        # Action distribution
        self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, use_pca=use_pca, n_envs=self.n_envs, dist_kwargs=dist_kwargs)

        self._build(lr_schedule)
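    # Illustrative sketch (not the library's documented API): constructing this policy
    # directly. ``obs_space`` and ``act_space`` are assumed Gymnasium spaces and the
    # lambda is a constant learning-rate schedule. Other projection layers imported
    # above (KL, Wasserstein, Frobenius) can be passed instead of the identity one,
    # with their own trust-region parameters.
    #
    #     policy = ActorCriticPolicy(
    #         observation_space=obs_space,
    #         action_space=act_space,
    #         lr_schedule=lambda progress_remaining: 3e-4,
    #         policy_projection=IdentityProjectionLayer(),
    #     )
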
    def _get_constructor_parameters(self) -> Dict[str, Any]:
        data = super()._get_constructor_parameters()

        default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None)

        data.update(
            dict(
                net_arch=self.net_arch,
                activation_fn=self.activation_fn,
                use_sde=self.use_sde,
                log_std_init=self.log_std_init,
                squash_output=default_none_kwargs["squash_output"],
                full_std=default_none_kwargs["full_std"],
                use_expln=default_none_kwargs["use_expln"],
                lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
                ortho_init=self.ortho_init,
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                features_extractor_class=self.features_extractor_class,
                features_extractor_kwargs=self.features_extractor_kwargs,
            )
        )
        return data

    def reset_noise(self, n_envs: int = 1) -> None:
        """
        Sample new weights for the exploration matrix.

        :param n_envs:
        """
        if isinstance(self.action_dist, StateDependentNoiseDistribution):
            self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
        else:
            self.action_dist.base_noise.reset()

    def _build_mlp_extractor(self) -> None:
        """
        Create the policy and value networks.
        Part of the layers can be shared.
        """
        # Note: If net_arch is None and some features extractor is used,
        # net_arch here is an empty list and mlp_extractor does not
        # really contain any layers (acts like an identity module).
        self.mlp_extractor = MlpExtractor(
            self.features_dim,
            net_arch=self.net_arch,
            activation_fn=self.activation_fn,
            device=self.device,
        )
    def _build(self, lr_schedule: Schedule) -> None:
        """
        Create the networks and the optimizer.

        :param lr_schedule: Learning rate schedule
            lr_schedule(1) is the initial learning rate
        """
        self._build_mlp_extractor()

        latent_dim_pi = self.mlp_extractor.latent_dim_pi

        if isinstance(self.action_dist, DiagGaussianDistribution):
            self.action_net, self.log_std = self.action_dist.proba_distribution_net(
                latent_dim=latent_dim_pi, log_std_init=self.log_std_init
            )
        elif isinstance(self.action_dist, StateDependentNoiseDistribution):
            self.action_net, self.log_std = self.action_dist.proba_distribution_net(
                latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init
            )
        elif isinstance(self.action_dist, PCA_Distribution):
            self.action_net, self.std_net = self.action_dist.proba_distribution_net(
                latent_dim=latent_dim_pi
            )
        elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)):
            self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
        else:
            raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.")

        self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)
        # Init weights: use orthogonal initialization
        # with small initial weight for the output
        if self.ortho_init:
            # TODO: check for features_extractor
            # Values from stable-baselines.
            # features_extractor/mlp values are
            # originally from openai/baselines (default gains/init_scales).
            module_gains = {
                self.features_extractor: np.sqrt(2),
                self.mlp_extractor: np.sqrt(2),
                self.action_net: 0.01,
                self.value_net: 1,
            }
            if not self.share_features_extractor:
                # Note(antonin): this is to keep SB3 results
                # consistent, see GH#1148
                del module_gains[self.features_extractor]
                module_gains[self.pi_features_extractor] = np.sqrt(2)
                module_gains[self.vf_features_extractor] = np.sqrt(2)

            for module, gain in module_gains.items():
                module.apply(partial(self.init_weights, gain=gain))

        # Setup optimizer with initial learning rate
        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
    def forward(self, obs: th.Tensor, deterministic: bool = False, conditioned_log_probs: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor, Distribution]:
        """
        Forward pass in all the networks (actor and critic)

        :param obs: Observation
        :param deterministic: Whether to sample or use deterministic actions
        :param conditioned_log_probs: Whether to compute conditioned log probabilities
            via the PCA distribution (requires ``use_pca=True``)
        :return: action, value, log probability of the action and the action distribution
        """
        # Preprocess the observation if needed
        features = self.extract_features(obs)
        if self.share_features_extractor:
            latent_pi, latent_vf = self.mlp_extractor(features)
        else:
            pi_features, vf_features = features
            latent_pi = self.mlp_extractor.forward_actor(pi_features)
            latent_vf = self.mlp_extractor.forward_critic(vf_features)
        # Evaluate the values for the given observations
        values = self.value_net(latent_vf)
        distribution = self._get_action_dist_from_latent(latent_pi)
        actions = distribution.get_actions(deterministic=deterministic)
        if conditioned_log_probs:
            assert self.use_pca, 'Cannot calculate conditioned log probs when PCA is disabled.'
            log_prob = distribution.conditioned_log_prob(actions)
        else:
            log_prob = distribution.log_prob(actions)
        actions = actions.reshape((-1, *self.action_space.shape))
        return actions, values, log_prob, distribution
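    # Illustrative sketch (assumes a batched observation tensor ``obs``): the forward
    # pass returns everything a rollout collector needs in one call.
    #
    #     actions, values, log_probs, dist = policy(obs)
    #     entropy = dist.entropy()
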
    def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
        """
        Preprocess the observation if needed and extract features.

        :param obs: Observation
        :return: the output of the features extractor(s)
        """
        if self.share_features_extractor:
            return super().extract_features(obs, self.features_extractor)
        else:
            pi_features = super().extract_features(obs, self.pi_features_extractor)
            vf_features = super().extract_features(obs, self.vf_features_extractor)
            return pi_features, vf_features

    def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
        """
        Retrieve action distribution given the latent codes.

        :param latent_pi: Latent code for the actor
        :return: Action distribution
        """
        mean_actions = self.action_net(latent_pi)

        if isinstance(self.action_dist, DiagGaussianDistribution):
            return self.action_dist.proba_distribution(mean_actions, self.log_std)
        elif isinstance(self.action_dist, CategoricalDistribution):
            # Here mean_actions are the logits before the softmax
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, MultiCategoricalDistribution):
            # Here mean_actions are the flattened logits
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, BernoulliDistribution):
            # Here mean_actions are the logits (before rounding to get the binary actions)
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, StateDependentNoiseDistribution):
            return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
        elif isinstance(self.action_dist, PCA_Distribution):
            std_actions = self.std_net(latent_pi)
            self.log_std = th.log(std_actions)
            return self.action_dist.proba_distribution(mean_actions, std_actions)
        else:
            raise ValueError("Invalid action distribution")
    def _predict(self, observation: th.Tensor, deterministic: bool = False, trajectory: Optional[th.Tensor] = None) -> th.Tensor:
        """
        Get the action according to the policy for a given observation.

        :param observation:
        :param deterministic: Whether to use stochastic or deterministic actions
        :param trajectory: Optional trajectory conditioning (only used with the PCA distribution)
        :return: Taken action according to the policy
        """
        if self.use_pca:
            return self.get_distribution(observation).get_actions(deterministic=deterministic, trajectory=trajectory)
        else:
            return self.get_distribution(observation).get_actions(deterministic=deterministic)
    def evaluate_actions(self, rollout_data: RolloutBufferSamples, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor], th.Tensor]:
        """
        Evaluate actions according to the current policy,
        given the observations.

        :param rollout_data: The rollout samples (containing the observations and
            the data the projection layer needs to recover the old policy distribution)
        :param actions: Actions
        :return: estimated value, log likelihood of taking those actions,
            entropy of the action distribution and the trust-region loss.
        """
        # Preprocess the observation if needed
        obs = rollout_data.observations
        features = self.extract_features(obs)
        if self.share_features_extractor:
            latent_pi, latent_vf = self.mlp_extractor(features)
        else:
            pi_features, vf_features = features
            latent_pi = self.mlp_extractor.forward_actor(pi_features)
            latent_vf = self.mlp_extractor.forward_critic(vf_features)
        raw_distribution = self._get_action_dist_from_latent(latent_pi)
        distribution, old_distribution = self.policy_projection.project_from_rollouts(raw_distribution, rollout_data)
        log_prob = distribution.log_prob(actions)
        values = self.value_net(latent_vf)
        entropy = distribution.entropy()
        trust_region_loss = self.policy_projection.get_trust_region_loss(raw_distribution, old_distribution)
        return values, log_prob, entropy, trust_region_loss
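    # Illustrative sketch (not the training loop of this repository): how the four
    # return values might feed a projection-aware PPO/TRPL-style update.
    # ``rollout_data`` is assumed to be a RolloutBufferSamples batch and ``trl_coef``
    # a hypothetical weighting coefficient for the trust-region loss.
    #
    #     values, log_prob, entropy, trust_region_loss = policy.evaluate_actions(
    #         rollout_data, rollout_data.actions
    #     )
    #     ratio = th.exp(log_prob - rollout_data.old_log_prob)
    #     policy_loss = -(rollout_data.advantages * ratio).mean()
    #     total_loss = policy_loss + trl_coef * trust_region_loss
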
    def get_distribution(self, obs: th.Tensor) -> Distribution:
        """
        Get the current policy distribution given the observations.

        :param obs:
        :return: the action distribution.
        """
        features = super().extract_features(obs, self.pi_features_extractor)
        latent_pi = self.mlp_extractor.forward_actor(features)
        return self._get_action_dist_from_latent(latent_pi)

    def predict_values(self, obs: th.Tensor) -> th.Tensor:
        """
        Get the estimated values according to the current policy given the observations.

        :param obs: Observation
        :return: the estimated values.
        """
        features = super().extract_features(obs, self.vf_features_extractor)
        latent_vf = self.mlp_extractor.forward_critic(features)
        return self.value_net(latent_vf)
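    # Illustrative sketch (assumes an already batched observation tensor ``obs_tensor``):
    # at the end of a rollout, bootstrapping code typically calls
    #
    #     last_values = policy.predict_values(obs_tensor)
    #
    # while ``policy.get_distribution(obs_tensor)`` exposes the current (unprojected)
    # action distribution, e.g. for logging its entropy.

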
# Cap the standard deviation of the SAC actor, as in the original SAC implementation
LOG_STD_MIN, LOG_STD_MAX = -20, 2


class Actor(BasePolicy):
    """
    Actor network (policy) for SAC.

    :param observation_space: Observation space
    :param action_space: Action space
    :param net_arch: Network architecture
    :param features_extractor: Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param features_dim: Number of features
    :param activation_fn: Activation function
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param use_pca: Whether to use the ``PCA_Distribution`` for exploration
        instead of a squashed diagonal Gaussian
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE.
    :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param dist_kwargs: Additional keyword arguments to pass to the action distribution
    """

    action_space: spaces.Box

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Box,
        net_arch: List[int],
        features_extractor: nn.Module,
        features_dim: int,
        activation_fn: Type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        use_pca: bool = False,
        full_std: bool = True,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        normalize_images: bool = True,
        dist_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            observation_space,
            action_space,
            features_extractor=features_extractor,
            normalize_images=normalize_images,
            squash_output=True,
        )

        # Avoid the mutable default argument pitfall
        dist_kwargs = {} if dist_kwargs is None else dist_kwargs

        # Allow activation functions to be given by name (e.g. from a config file)
        if activation_fn == 'ReLU':
            activation_fn = nn.ReLU
        elif activation_fn == 'tanh':
            activation_fn = nn.Tanh

        self.use_sde = use_sde
        self.use_pca = use_pca
        self.sde_features_extractor = None
        self.net_arch = net_arch
        self.features_dim = features_dim
        self.activation_fn = activation_fn
        self.log_std_init = log_std_init
        self.use_expln = use_expln
        self.full_std = full_std
        self.clip_mean = clip_mean

        action_dim = get_action_dim(self.action_space)
        latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)
        self.latent_pi = nn.Sequential(*latent_pi_net)
        last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim

        assert not (self.use_sde and self.use_pca), "use_sde and use_pca are mutually exclusive"

        # Fallback used by get_std() until the first forward pass stores the predicted log std
        self._remember_log_std = th.Tensor([log_std_init])

        if self.use_pca:
            self.action_dist = PCA_Distribution(
                action_dim, **dist_kwargs
            )
            self.mu, self.log_std = self.action_dist.proba_distribution_net(
                latent_dim=last_layer_dim, return_log_std=True
            )
            # Avoid numerical issues by limiting the mean of the Gaussian
            # to be in [-clip_mean, clip_mean]
            if clip_mean > 0.0:
                self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
        elif self.use_sde:
            self.action_dist = StateDependentNoiseDistribution(
                action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
            )
            self.mu, self.log_std = self.action_dist.proba_distribution_net(
                latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init
            )
            # Avoid numerical issues by limiting the mean of the Gaussian
            # to be in [-clip_mean, clip_mean]
            if clip_mean > 0.0:
                self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
        else:
            self.action_dist = SquashedDiagGaussianDistribution(action_dim)
            self.mu = nn.Linear(last_layer_dim, action_dim)
            self.log_std = nn.Linear(last_layer_dim, action_dim)
    def _get_constructor_parameters(self) -> Dict[str, Any]:
        data = super()._get_constructor_parameters()

        data.update(
            dict(
                net_arch=self.net_arch,
                features_dim=self.features_dim,
                activation_fn=self.activation_fn,
                use_sde=self.use_sde,
                log_std_init=self.log_std_init,
                full_std=self.full_std,
                use_expln=self.use_expln,
                features_extractor=self.features_extractor,
                clip_mean=self.clip_mean,
            )
        )
        return data

    def get_std(self) -> th.Tensor:
        """
        Retrieve the standard deviation of the action distribution.
        Only useful when using gSDE.
        It corresponds to ``th.exp(log_std)`` in the normal case,
        but is slightly different when using ``expln`` function
        (cf StateDependentNoiseDistribution doc).

        :return:
        """
        if isinstance(self.action_dist, StateDependentNoiseDistribution):
            return self.action_dist.get_std(self.log_std)
        else:
            return th.exp(self._remember_log_std)

    def reset_noise(self, batch_size: int = 1) -> None:
        """
        Sample new weights for the exploration matrix, when using gSDE.

        :param batch_size:
        """
        if isinstance(self.action_dist, StateDependentNoiseDistribution):
            self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
        else:
            self.action_dist.base_noise.reset()
    def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
        """
        Get the parameters for the action distribution.

        :param obs:
        :return:
            Mean, standard deviation and optional keyword arguments.
        """
        features = self.extract_features(obs, self.features_extractor)
        latent_pi = self.latent_pi(features)
        mean_actions = self.mu(latent_pi)

        if self.use_sde:
            return mean_actions, self.log_std, dict(latent_sde=latent_pi)
        # Unstructured exploration (Original implementation)
        log_std = self.log_std(latent_pi)
        self._remember_log_std = log_std
        # Original implementation to cap the standard deviation
        log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        return mean_actions, log_std, {}

    def forward(self, obs: th.Tensor, deterministic: bool = False, trajectory: Optional[th.Tensor] = None) -> th.Tensor:
        mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
        # Note: the action is squashed
        if isinstance(self.action_dist, PCA_Distribution):
            return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, trajectory=trajectory, **kwargs)
        else:
            return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)

    def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
        # return action and associated log prob
        return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)

    def _predict(self, observation: th.Tensor, deterministic: bool = False, trajectory: Optional[th.Tensor] = None) -> th.Tensor:
        return self(observation, deterministic, trajectory=trajectory)
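    # Illustrative sketch (assumes a batched observation tensor ``obs``): during SAC
    # training the actor is typically queried for a reparameterized action together
    # with its log probability, e.g.
    #
    #     actions_pi, log_prob = actor.action_log_prob(obs)
    #
    # while plain ``actor(obs, deterministic=True)`` returns the squashed mean action.

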
class SACPolicy(BasePolicy):
    """
    Policy class (with both actor and critic) for SAC.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param use_pca: Whether to use the ``PCA_Distribution`` for the actor
        instead of a squashed diagonal Gaussian
    :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether to share or not the features extractor
        between the actor and the critic (this saves computation time)
    :param dist_kwargs: Additional keyword arguments to pass to the action distribution
    """

    actor: Actor
    critic: ContinuousCritic
    critic_target: ContinuousCritic

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Box,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        use_pca: bool = False,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        n_critics: int = 2,
        share_features_extractor: bool = False,
        dist_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            squash_output=True,
            normalize_images=normalize_images,
        )

        # Avoid the mutable default argument pitfall
        dist_kwargs = {} if dist_kwargs is None else dist_kwargs

        # Allow activation functions to be given by name (e.g. from a config file)
        if activation_fn == 'ReLU':
            activation_fn = nn.ReLU
        elif activation_fn == 'tanh':
            activation_fn = nn.Tanh

        if net_arch is None:
            net_arch = [256, 256]

        actor_arch, critic_arch = get_actor_critic_arch(net_arch)

        self.net_arch = net_arch
        self.activation_fn = activation_fn
        self.net_args = {
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "net_arch": actor_arch,
            "activation_fn": self.activation_fn,
            "normalize_images": normalize_images,
        }
        self.actor_kwargs = self.net_args.copy()

        self.actor_kwargs.update({
            "use_sde": use_sde,
            "use_pca": use_pca,
            "log_std_init": log_std_init,
            "use_expln": use_expln,
            "clip_mean": clip_mean,
            "dist_kwargs": dist_kwargs,
        })
        self.critic_kwargs = self.net_args.copy()
        self.critic_kwargs.update(
            {
                "n_critics": n_critics,
                "net_arch": critic_arch,
                "share_features_extractor": share_features_extractor,
            }
        )

        self.actor, self.actor_target = None, None
        self.critic, self.critic_target = None, None
        self.share_features_extractor = share_features_extractor

        self._build(lr_schedule)
    def _build(self, lr_schedule: Schedule) -> None:
        self.actor = self.make_actor()
        self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

        if self.share_features_extractor:
            self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
            # Do not optimize the shared features extractor with the critic loss
            # otherwise, there are gradient computation issues
            critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
        else:
            # Create a separate features extractor for the critic
            # this requires more memory and computation
            self.critic = self.make_critic(features_extractor=None)
            critic_parameters = list(self.critic.parameters())

        # Critic target should not share the features extractor with critic
        self.critic_target = self.make_critic(features_extractor=None)
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)

        # Target networks should always be in eval mode
        self.critic_target.set_training_mode(False)
    def _get_constructor_parameters(self) -> Dict[str, Any]:
        data = super()._get_constructor_parameters()

        data.update(
            dict(
                net_arch=self.net_arch,
                activation_fn=self.net_args["activation_fn"],
                use_sde=self.actor_kwargs["use_sde"],
                log_std_init=self.actor_kwargs["log_std_init"],
                use_expln=self.actor_kwargs["use_expln"],
                clip_mean=self.actor_kwargs["clip_mean"],
                n_critics=self.critic_kwargs["n_critics"],
                lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                features_extractor_class=self.features_extractor_class,
                features_extractor_kwargs=self.features_extractor_kwargs,
            )
        )
        return data
    def reset_noise(self, batch_size: int = 1) -> None:
        """
        Sample new weights for the exploration matrix, when using gSDE.

        :param batch_size:
        """
        # The actor decides internally whether to resample gSDE weights or to reset
        # its base noise, so simply delegate.
        self.actor.reset_noise(batch_size=batch_size)
    def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
        actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
        return Actor(**actor_kwargs).to(self.device)

    def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ContinuousCritic:
        critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
        return ContinuousCritic(**critic_kwargs).to(self.device)

    def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
        return self._predict(obs, deterministic=deterministic)

    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
        return self.actor(observation, deterministic)

    def set_training_mode(self, mode: bool) -> None:
        """
        Put the policy in either training or evaluation mode.

        This affects certain modules, such as batch normalisation and dropout.

        :param mode: if true, set to training mode, else set to evaluation mode
        """
        self.actor.set_training_mode(mode)
        self.critic.set_training_mode(mode)
        self.training = mode


SACMlpPolicy = SACPolicy
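
# Illustrative sketch (assumptions: a Gymnasium environment ``env`` and whichever
# SAC-style algorithm class this repository pairs with these policies): the policy
# classes above are normally not instantiated by hand but passed to an algorithm,
# roughly
#
#     model = SAC(SACMlpPolicy, env, policy_kwargs=dict(net_arch=[256, 256]))
#     model.learn(total_timesteps=100_000)
#
# where ``SAC`` here stands for the off-policy algorithm implementation that
# consumes ``SACPolicy`` in this code base.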