Renamed TRL_PG to PPO

This commit is contained in:
Dominik Moritz Roth 2022-07-13 19:51:33 +02:00
parent 1706bea571
commit b1ed9fc2b8
7 changed files with 1840 additions and 9 deletions

View File

@ -0,0 +1,2 @@
from ..trl_pg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from ..trl_pg.trl_pg import TRL_PG

View File

@ -0,0 +1,514 @@
import collections
import warnings
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
import numpy as np
import torch as th
from torch import nn
import math
from stable_baselines3.common.distributions import (
BernoulliDistribution,
CategoricalDistribution,
DiagGaussianDistribution,
Distribution,
MultiCategoricalDistribution,
StateDependentNoiseDistribution,
)
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
FlattenExtractor,
MlpExtractor,
NatureCNN,
)
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
FlattenExtractor,
NatureCNN,
)
from ..distributions import UniversalGaussianDistribution, make_proba_distribution
class ActorCriticPolicy(BasePolicy):
"""
Code stolen from SB3
Policy class for actor-critic algorithms (has both policy and value prediction).
Used by A2C, PPO and the likes.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param ortho_init: Whether to use or not orthogonal initialization
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE
:param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: Whether to squash the output using a tanh function,
this allows to ensure boundaries when using gSDE.
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
# TODO: Allow passing of dist_kwargs into dist
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
if optimizer_kwargs is None:
optimizer_kwargs = {}
# Small values to avoid NaN in Adam optimizer
if optimizer_class == th.optim.Adam:
optimizer_kwargs["eps"] = 1e-5
super().__init__(
observation_space,
action_space,
features_extractor_class,
features_extractor_kwargs,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
squash_output=squash_output,
)
# Default network architecture, from stable-baselines
if net_arch is None:
if features_extractor_class == NatureCNN:
net_arch = []
else:
net_arch = [dict(pi=[64, 64], vf=[64, 64])]
self.net_arch = net_arch
self.activation_fn = activation_fn
self.ortho_init = ortho_init
self.features_extractor = features_extractor_class(
self.observation_space, **self.features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim
self.normalize_images = normalize_images
self.log_std_init = log_std_init
dist_kwargs = None
# Keyword arguments for gSDE distribution
if use_sde:
dist_kwargs = {
"full_std": full_std,
"squash_output": squash_output,
"use_expln": use_expln,
"learn_features": False,
}
if sde_net_arch is not None:
warnings.warn(
"sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
self.use_sde = use_sde
self.dist_kwargs = dist_kwargs
# Action distribution
self.action_dist = make_proba_distribution(
action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)
self._build(lr_schedule)
def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_constructor_parameters()
default_none_kwargs = self.dist_kwargs or collections.defaultdict(
lambda: None)
data.update(
dict(
net_arch=self.net_arch,
activation_fn=self.activation_fn,
use_sde=self.use_sde,
log_std_init=self.log_std_init,
squash_output=default_none_kwargs["squash_output"],
full_std=default_none_kwargs["full_std"],
use_expln=default_none_kwargs["use_expln"],
# dummy lr schedule, not needed for loading policy alone
lr_schedule=self._dummy_schedule,
ortho_init=self.ortho_init,
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs,
)
)
return data
def reset_noise(self, n_envs: int = 1) -> None:
"""
Sample new weights for the exploration matrix.
:param n_envs:
"""
assert isinstance(
self.action_dist, StateDependentNoiseDistribution) or isinstance(
self.action_dist, UniversalGaussianDistribution), "reset_noise() is only available when using gSDE"
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
def _build_mlp_extractor(self) -> None:
"""
Create the policy and value networks.
Part of the layers can be shared.
"""
# Note: If net_arch is None and some features extractor is used,
# net_arch here is an empty list and mlp_extractor does not
# really contain any layers (acts like an identity module).
self.mlp_extractor = MlpExtractor(
self.features_dim,
net_arch=self.net_arch,
activation_fn=self.activation_fn,
device=self.device,
)
def _build(self, lr_schedule: Schedule) -> None:
"""
Create the networks and the optimizer.
:param lr_schedule: Learning rate schedule
lr_schedule(1) is the initial learning rate
"""
self._build_mlp_extractor()
latent_dim_pi = self.mlp_extractor.latent_dim_pi
if isinstance(self.action_dist, DiagGaussianDistribution):
self.action_net, self.log_std = self.action_dist.proba_distribution_net(
latent_dim=latent_dim_pi, log_std_init=self.log_std_init
)
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
self.action_net, self.log_std = self.action_dist.proba_distribution_net(
latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init
)
elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)):
self.action_net = self.action_dist.proba_distribution_net(
latent_dim=latent_dim_pi)
elif isinstance(self.action_dist, UniversalGaussianDistribution):
self.action_net, self.chol_net = self.action_dist.proba_distribution_net(
latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, std_init=math.exp(
self.log_std_init)
)
else:
raise NotImplementedError(
f"Unsupported distribution '{self.action_dist}'.")
self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)
# Init weights: use orthogonal initialization
# with small initial weight for the output
if self.ortho_init:
# TODO: check for features_extractor
# Values from stable-baselines.
# features_extractor/mlp values are
# originally from openai/baselines (default gains/init_scales).
module_gains = {
self.features_extractor: np.sqrt(2),
self.mlp_extractor: np.sqrt(2),
self.action_net: 0.01,
self.value_net: 1,
}
for module, gain in module_gains.items():
module.apply(partial(self.init_weights, gain=gain))
# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(
self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
"""
Forward pass in all the networks (actor and critic)
:param obs: Observation
:param deterministic: Whether to sample or use deterministic actions
:return: action, value and log probability of the action
"""
# Preprocess the observation if needed
features = self.extract_features(obs)
latent_pi, latent_vf = self.mlp_extractor(features)
# Evaluate the values for the given observations
values = self.value_net(latent_vf)
distribution = self._get_action_dist_from_latent(latent_pi)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
return actions, values, log_prob
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
"""
Retrieve action distribution given the latent codes.
:param latent_pi: Latent code for the actor
:return: Action distribution
"""
mean_actions = self.action_net(latent_pi)
if isinstance(self.action_dist, DiagGaussianDistribution):
return self.action_dist.proba_distribution(mean_actions, self.log_std)
elif isinstance(self.action_dist, CategoricalDistribution):
# Here mean_actions are the logits before the softmax
return self.action_dist.proba_distribution(action_logits=mean_actions)
elif isinstance(self.action_dist, MultiCategoricalDistribution):
# Here mean_actions are the flattened logits
return self.action_dist.proba_distribution(action_logits=mean_actions)
elif isinstance(self.action_dist, BernoulliDistribution):
# Here mean_actions are the logits (before rounding to get the binary actions)
return self.action_dist.proba_distribution(action_logits=mean_actions)
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
elif isinstance(self.action_dist, UniversalGaussianDistribution):
chol = self.chol_net(latent_pi)
self.chol = chol
return self.action_dist.proba_distribution(mean_actions, chol, latent_pi)
else:
raise ValueError("Invalid action distribution")
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
"""
Get the action according to the policy for a given observation.
:param observation:
:param deterministic: Whether to use stochastic or deterministic actions
:return: Taken action according to the policy
"""
return self.get_distribution(observation).get_actions(deterministic=deterministic)
def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
"""
Evaluate actions according to the current policy,
given the observations.
:param obs:
:param actions:
:return: estimated value, log likelihood of taking those actions
and entropy of the action distribution.
"""
# Preprocess the observation if needed
features = self.extract_features(obs)
latent_pi, latent_vf = self.mlp_extractor(features)
distribution = self._get_action_dist_from_latent(latent_pi)
log_prob = distribution.log_prob(actions)
values = self.value_net(latent_vf)
return values, log_prob, distribution.entropy()
def get_distribution(self, obs: th.Tensor) -> Distribution:
"""
Get the current policy distribution given the observations.
:param obs:
:return: the action distribution.
"""
features = self.extract_features(obs)
latent_pi = self.mlp_extractor.forward_actor(features)
return self._get_action_dist_from_latent(latent_pi)
def predict_values(self, obs: th.Tensor) -> th.Tensor:
"""
Get the estimated values according to the current policy given the observations.
:param obs:
:return: the estimated values.
"""
features = self.extract_features(obs)
latent_vf = self.mlp_extractor.forward_critic(features)
return self.value_net(latent_vf)
class ActorCriticCnnPolicy(ActorCriticPolicy):
"""
Code stolen from SB3
CNN policy class for actor-critic algorithms (has both policy and value prediction).
Used by A2C, PPO and the likes.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param ortho_init: Whether to use or not orthogonal initialization
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE
:param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: Whether to squash the output using a tanh function,
this allows to ensure boundaries when using gSDE.
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
ortho_init,
use_sde,
log_std_init,
full_std,
sde_net_arch,
use_expln,
squash_output,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
)
class MultiInputActorCriticPolicy(ActorCriticPolicy):
"""
Code stolen from SB3
MultiInputActorClass policy class for actor-critic algorithms (has both policy and value prediction).
Used by A2C, PPO and the likes.
:param observation_space: Observation space (Tuple)
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param ortho_init: Whether to use or not orthogonal initialization
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE
:param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: Whether to squash the output using a tanh function,
this allows to ensure boundaries when using gSDE.
:param features_extractor_class: Uses the CombinedExtractor
:param features_extractor_kwargs: Keyword arguments
to pass to the feature extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(
self,
observation_space: gym.spaces.Dict,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
ortho_init,
use_sde,
log_std_init,
full_std,
sde_net_arch,
use_expln,
squash_output,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
)
MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
MultiInputPolicy = MultiInputActorCriticPolicy

View File

@ -0,0 +1,391 @@
import warnings
from typing import Any, Dict, Optional, Type, Union, NamedTuple
import numpy as np
import torch as th
from gym import spaces
from torch.nn import functional as F
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.utils import obs_as_tensor
from stable_baselines3.common.vec_env import VecNormalize
from ..misc.distTools import new_dist_like
from ..projections.base_projection_layer import BaseProjectionLayer
from ..projections.frob_projection_layer import FrobeniusProjectionLayer
from ..projections.w2_projection_layer import WassersteinProjectionLayer
from ..projections.kl_projection_layer import KLProjectionLayer
from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass
class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
"""
Differential Trust Region Layer (TRL) for Policy Gradient (PG)
Paper: https://arxiv.org/abs/2101.09207
Code: This implementation borrows (/steals most) code from SB3's PPO implementation https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/ppo/ppo.py
The implementation of the TRL-specific parts borrows from https://github.com/boschresearch/trust-region-layers/blob/main/trust_region_projections/algorithms/pg/pg.py (Stolen from Fabian's Code (Public Version))
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: The learning rate, it can be a function
of the current progress remaining (from 1 to 0)
:param n_steps: The number of steps to run for each environment per update
(i.e. rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel)
NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization)
See https://github.com/pytorch/pytorch/issues/29372
:param batch_size: Minibatch size
:param n_epochs: Number of epoch when optimizing the surrogate loss
:param gamma: Discount factor
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
:param clip_range: Clipping parameter, it can be a function of the current progress
remaining (from 1 to 0).
:param clip_range_vf: Clipping parameter for the value function,
it can be a function of the current progress remaining (from 1 to 0).
This is a parameter specific to the OpenAI implementation. If None is passed (default),
no clipping will be done on the value function.
IMPORTANT: this clipping depends on the reward scaling.
:param normalize_advantage: Whether to normalize or not the advantage
:param ent_coef: Entropy coefficient for the loss calculation
:param vf_coef: Value function coefficient for the loss calculation
:param max_grad_norm: The maximum value for the gradient clipping
:param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param target_kl: Limit the KL divergence between updates,
because the clipping is not enough to prevent large update
# 213 (cf https://github.com/hill-a/stable-baselines/issues/213)
see issue
By default, there is no limit on the kl div.
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param projection: What kind of Projection to use
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": ActorCriticPolicy,
"CnnPolicy": ActorCriticCnnPolicy,
"MultiInputPolicy": MultiInputActorCriticPolicy,
}
def __init__(
self,
policy: Union[str, Type[ActorCriticPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 3e-4,
n_steps: int = 2048,
batch_size: int = 64,
n_epochs: int = 10,
gamma: float = 0.99,
gae_lambda: float = 0.95,
clip_range: Union[float, Schedule] = 0.2,
clip_range_vf: Union[None, float, Schedule] = None,
normalize_advantage: bool = True,
ent_coef: float = 0.0,
vf_coef: float = 0.5,
max_grad_norm: float = 0.5,
use_sde: bool = False,
sde_sample_freq: int = -1,
target_kl: Optional[float] = None,
tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
# Different from PPO:
#projection: BaseProjectionLayer = KLProjectionLayer(),
#projection: BaseProjectionLayer = WassersteinProjectionLayer(),
#projection: BaseProjectionLayer = FrobeniusProjectionLayer(),
projection: BaseProjectionLayer = BaseProjectionLayer(),
_init_setup_model: bool = True,
):
super().__init__(
policy,
env,
learning_rate=learning_rate,
n_steps=n_steps,
gamma=gamma,
gae_lambda=gae_lambda,
ent_coef=ent_coef,
vf_coef=vf_coef,
max_grad_norm=max_grad_norm,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
tensorboard_log=tensorboard_log,
policy_kwargs=policy_kwargs,
verbose=verbose,
device=device,
create_eval_env=create_eval_env,
seed=seed,
_init_setup_model=False,
supported_action_spaces=(
spaces.Box,
# spaces.Discrete,
# spaces.MultiDiscrete,
# spaces.MultiBinary,
),
)
# Sanity check, otherwise it will lead to noisy gradient and NaN
# because of the advantage normalization
if normalize_advantage:
assert (
batch_size > 1
), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"
if self.env is not None:
# Check that `n_steps * n_envs > 1` to avoid NaN
# when doing advantage normalization
buffer_size = self.env.num_envs * self.n_steps
assert (
buffer_size > 1
), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
# Check that the rollout buffer size is a multiple of the mini-batch size
untruncated_batches = buffer_size // batch_size
if buffer_size % batch_size > 0:
warnings.warn(
f"You have specified a mini-batch size of {batch_size},"
f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`,"
f" after every {untruncated_batches} untruncated mini-batches,"
f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n"
f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n"
f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})"
)
self.batch_size = batch_size
self.n_epochs = n_epochs
self.clip_range = clip_range
self.clip_range_vf = clip_range_vf
self.normalize_advantage = normalize_advantage
self.target_kl = target_kl
# Different from PPO:
self.projection = projection
self._global_steps = 0
if _init_setup_model:
self._setup_model()
def _setup_model(self) -> None:
super()._setup_model()
# Initialize schedules for policy/value clipping
self.clip_range = get_schedule_fn(self.clip_range)
if self.clip_range_vf is not None:
if isinstance(self.clip_range_vf, (float, int)):
assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
def train(self) -> None:
"""
Update policy using the currently gathered rollout buffer.
"""
# Switch to train mode (this affects batch norm / dropout)
self.policy.set_training_mode(True)
# Update optimizer learning rate
self._update_learning_rate(self.policy.optimizer)
# Compute current clip range
clip_range = self.clip_range(self._current_progress_remaining)
# Optional: clip range for the value function
if self.clip_range_vf is not None:
clip_range_vf = self.clip_range_vf(
self._current_progress_remaining)
surrogate_losses = []
entropy_losses = []
trust_region_losses = []
pg_losses, value_losses = [], []
clip_fractions = []
continue_training = True
# train for n_epochs epochs
for epoch in range(self.n_epochs):
approx_kl_divs = []
# Do a complete pass on the rollout buffer
for rollout_data in self.rollout_buffer.get(self.batch_size):
# This is new compared to PPO.
# Calculating the TR-Projections we need to know the step number
self._global_steps += 1
actions = rollout_data.actions
if isinstance(self.action_space, spaces.Discrete):
# Convert discrete action from float to long
actions = rollout_data.actions.long().flatten()
# Re-sample the noise matrix because the log_std has changed
if self.use_sde:
self.policy.reset_noise(self.batch_size)
# Different from PPO
# TRL-Projection-Action:
pol = self.policy
features = pol.extract_features(rollout_data.observations)
latent_pi, latent_vf = pol.mlp_extractor(features)
p = pol._get_action_dist_from_latent(latent_pi)
p_dist = p.distribution
q_dist = new_dist_like(
p_dist, rollout_data.means, rollout_data.stds)
proj_p = self.projection(p_dist, q_dist, self._global_steps)
log_prob = proj_p.log_prob(actions).sum(dim=1)
values = self.policy.value_net(latent_vf)
entropy = proj_p.entropy()
values = values.flatten()
# Normalize advantage
advantages = rollout_data.advantages
if self.normalize_advantage:
advantages = (advantages - advantages.mean()
) / (advantages.std() + 1e-8)
# ratio between old and new policy, should be one at the first iteration
ratio = th.exp(log_prob - rollout_data.old_log_prob)
# Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss'
# clipped surrogate loss
surrogate_loss_1 = advantages * ratio
surrogate_loss_2 = advantages * \
th.clamp(ratio, 1 - clip_range, 1 + clip_range)
surrogate_loss = - \
th.min(surrogate_loss_1, surrogate_loss_2).mean()
surrogate_losses.append(surrogate_loss.item())
clip_fraction = th.mean(
(th.abs(ratio - 1) > clip_range).float()).item()
clip_fractions.append(clip_fraction)
if self.clip_range_vf is None:
# No clipping
values_pred = values
else:
# Clip the different between old and new value
# NOTE: this depends on the reward scaling
values_pred = rollout_data.old_values + th.clamp(
values - rollout_data.old_values, -clip_range_vf, clip_range_vf
)
# Value loss using the TD(gae_lambda) target
value_loss = F.mse_loss(rollout_data.returns, values_pred)
value_losses.append(value_loss.item())
# Entropy loss favor exploration
if entropy is None:
# Approximate entropy when no analytical form
entropy_loss = -th.mean(-log_prob)
else:
entropy_loss = -th.mean(entropy)
entropy_losses.append(entropy_loss.item())
# Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
trust_region_loss = self.projection.get_trust_region_loss(
p, proj_p)
trust_region_losses.append(trust_region_loss.item())
policy_loss = surrogate_loss + self.ent_coef * entropy_loss + trust_region_loss
pg_losses.append(policy_loss.item())
loss = policy_loss + self.vf_coef * value_loss
# Calculate approximate form of reverse KL Divergence for early stopping
# see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
# and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
# and Schulman blog: http://joschu.net/blog/kl-approx.html
with th.no_grad():
log_ratio = log_prob - rollout_data.old_log_prob
approx_kl_div = th.mean(
(th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
approx_kl_divs.append(approx_kl_div)
if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
continue_training = False
if self.verbose >= 1:
print(
f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
break
# Optimization step
self.policy.optimizer.zero_grad()
loss.backward()
# Clip grad norm
th.nn.utils.clip_grad_norm_(
self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
if not continue_training:
break
self._n_updates += self.n_epochs
explained_var = explained_variance(
self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
# Logs
self.logger.record("train/surrogate_loss", np.mean(surrogate_losses))
self.logger.record("train/entropy_loss", np.mean(entropy_losses))
self.logger.record("train/trust_region_loss",
np.mean(trust_region_losses))
self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
self.logger.record("train/value_loss", np.mean(value_losses))
self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
self.logger.record("train/clip_fraction", np.mean(clip_fractions))
self.logger.record("train/loss", loss.item())
self.logger.record("train/explained_variance", explained_var)
if hasattr(self.policy, "log_std"):
self.logger.record(
"train/std", th.exp(self.policy.log_std).mean().item())
if hasattr(self.policy, "chol"):
self.logger.record(
"train/std", th.mean(th.diagonal(self.policy.chol, dim1=-2, dim2=-1)).mean().item())
self.logger.record("train/n_updates",
self._n_updates, exclude="tensorboard")
self.logger.record("train/clip_range", clip_range)
if self.clip_range_vf is not None:
self.logger.record("train/clip_range_vf", clip_range_vf)
def learn(
self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,
n_eval_episodes: int = 5,
tb_log_name: str = "TRL_PG",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> "TRL_PG":
return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
eval_env=eval_env,
eval_freq=eval_freq,
n_eval_episodes=n_eval_episodes,
tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
)

View File

@ -0,0 +1,2 @@
from sb3_trl.trl_sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_trl.trl_sac.trl_sac import TRL_SAC

View File

@ -0,0 +1,516 @@
import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
import torch as th
from torch import nn
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
FlattenExtractor,
NatureCNN,
create_mlp,
get_actor_critic_arch,
)
from stable_baselines3.common.type_aliases import Schedule
# CAP the standard deviation of the actor
LOG_STD_MAX = 2
LOG_STD_MIN = -20
class Actor(BasePolicy):
"""
Actor network (policy) for SAC.
:param observation_space: Obervation space
:param action_space: Action space
:param net_arch: Network architecture
:param features_extractor: Network to extract features
(a CNN when using images, a nn.Flatten() layer otherwise)
:param features_dim: Number of features
:param activation_fn: Activation function
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE.
:param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
net_arch: List[int],
features_extractor: nn.Module,
features_dim: int,
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
full_std: bool = True,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
clip_mean: float = 2.0,
normalize_images: bool = True,
):
super().__init__(
observation_space,
action_space,
features_extractor=features_extractor,
normalize_images=normalize_images,
squash_output=True,
)
# Save arguments to re-create object at loading
self.use_sde = use_sde
self.sde_features_extractor = None
self.net_arch = net_arch
self.features_dim = features_dim
self.activation_fn = activation_fn
self.log_std_init = log_std_init
self.sde_net_arch = sde_net_arch
self.use_expln = use_expln
self.full_std = full_std
self.clip_mean = clip_mean
if sde_net_arch is not None:
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
action_dim = get_action_dim(self.action_space)
latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)
self.latent_pi = nn.Sequential(*latent_pi_net)
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
if self.use_sde:
self.action_dist = StateDependentNoiseDistribution(
action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
)
self.mu, self.log_std = self.action_dist.proba_distribution_net(
latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init
)
# Avoid numerical issues by limiting the mean of the Gaussian
# to be in [-clip_mean, clip_mean]
if clip_mean > 0.0:
self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
else:
self.action_dist = SquashedDiagGaussianDistribution(action_dim)
self.mu = nn.Linear(last_layer_dim, action_dim)
self.log_std = nn.Linear(last_layer_dim, action_dim)
def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_constructor_parameters()
data.update(
dict(
net_arch=self.net_arch,
features_dim=self.features_dim,
activation_fn=self.activation_fn,
use_sde=self.use_sde,
log_std_init=self.log_std_init,
full_std=self.full_std,
use_expln=self.use_expln,
features_extractor=self.features_extractor,
clip_mean=self.clip_mean,
)
)
return data
def get_std(self) -> th.Tensor:
"""
Retrieve the standard deviation of the action distribution.
Only useful when using gSDE.
It corresponds to ``th.exp(log_std)`` in the normal case,
but is slightly different when using ``expln`` function
(cf StateDependentNoiseDistribution doc).
:return:
"""
msg = "get_std() is only available when using gSDE"
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
return self.action_dist.get_std(self.log_std)
def reset_noise(self, batch_size: int = 1) -> None:
"""
Sample new weights for the exploration matrix, when using gSDE.
:param batch_size:
"""
msg = "reset_noise() is only available when using gSDE"
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
"""
Get the parameters for the action distribution.
:param obs:
:return:
Mean, standard deviation and optional keyword arguments.
"""
features = self.extract_features(obs)
latent_pi = self.latent_pi(features)
mean_actions = self.mu(latent_pi)
if self.use_sde:
return mean_actions, self.log_std, dict(latent_sde=latent_pi)
# Unstructured exploration (Original implementation)
log_std = self.log_std(latent_pi)
# Original Implementation to cap the standard deviation
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
return mean_actions, log_std, {}
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
# Note: the action is squashed
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
# return action and associated log prob
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self(observation, deterministic)
class SACPolicy(BasePolicy):
"""
Policy class (with both actor and critic) for SAC.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
:param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param n_critics: Number of critic networks to create.
:param share_features_extractor: Whether to share or not the features extractor
between the actor and the critic (this saves computation time)
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
clip_mean: float = 2.0,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = True,
):
super().__init__(
observation_space,
action_space,
features_extractor_class,
features_extractor_kwargs,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
squash_output=True,
)
if net_arch is None:
if features_extractor_class == NatureCNN:
net_arch = []
else:
net_arch = [256, 256]
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
self.net_arch = net_arch
self.activation_fn = activation_fn
self.net_args = {
"observation_space": self.observation_space,
"action_space": self.action_space,
"net_arch": actor_arch,
"activation_fn": self.activation_fn,
"normalize_images": normalize_images,
}
self.actor_kwargs = self.net_args.copy()
if sde_net_arch is not None:
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
sde_kwargs = {
"use_sde": use_sde,
"log_std_init": log_std_init,
"use_expln": use_expln,
"clip_mean": clip_mean,
}
self.actor_kwargs.update(sde_kwargs)
self.critic_kwargs = self.net_args.copy()
self.critic_kwargs.update(
{
"n_critics": n_critics,
"net_arch": critic_arch,
"share_features_extractor": share_features_extractor,
}
)
self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None
self.share_features_extractor = share_features_extractor
self._build(lr_schedule)
def _build(self, lr_schedule: Schedule) -> None:
self.actor = self.make_actor()
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
if self.share_features_extractor:
self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
# Do not optimize the shared features extractor with the critic loss
# otherwise, there are gradient computation issues
critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
else:
# Create a separate features extractor for the critic
# this requires more memory and computation
self.critic = self.make_critic(features_extractor=None)
critic_parameters = self.critic.parameters()
# Critic target should not share the features extractor with critic
self.critic_target = self.make_critic(features_extractor=None)
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
# Target networks should always be in eval mode
self.critic_target.set_training_mode(False)
def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_constructor_parameters()
data.update(
dict(
net_arch=self.net_arch,
activation_fn=self.net_args["activation_fn"],
use_sde=self.actor_kwargs["use_sde"],
log_std_init=self.actor_kwargs["log_std_init"],
use_expln=self.actor_kwargs["use_expln"],
clip_mean=self.actor_kwargs["clip_mean"],
n_critics=self.critic_kwargs["n_critics"],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs,
)
)
return data
def reset_noise(self, batch_size: int = 1) -> None:
"""
Sample new weights for the exploration matrix, when using gSDE.
:param batch_size:
"""
self.actor.reset_noise(batch_size=batch_size)
def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
return Actor(**actor_kwargs).to(self.device)
def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ContinuousCritic:
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
return ContinuousCritic(**critic_kwargs).to(self.device)
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self.actor(observation, deterministic)
def set_training_mode(self, mode: bool) -> None:
"""
Put the policy in either training or evaluation mode.
This affects certain modules, such as batch normalisation and dropout.
:param mode: if true, set to training mode, else set to evaluation mode
"""
self.actor.set_training_mode(mode)
self.critic.set_training_mode(mode)
self.training = mode
MlpPolicy = SACPolicy
class CnnPolicy(SACPolicy):
"""
Policy class (with both actor and critic) for SAC.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
:param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
:param features_extractor_class: Features extractor to use.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param n_critics: Number of critic networks to create.
:param share_features_extractor: Whether to share or not the features extractor
between the actor and the critic (this saves computation time)
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
clip_mean: float = 2.0,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = True,
):
super().__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
use_sde,
log_std_init,
sde_net_arch,
use_expln,
clip_mean,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
n_critics,
share_features_extractor,
)
class MultiInputPolicy(SACPolicy):
"""
Policy class (with both actor and critic) for SAC.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
:param sde_net_arch: Network architecture for extracting features
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
:param features_extractor_class: Features extractor to use.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param n_critics: Number of critic networks to create.
:param share_features_extractor: Whether to share or not the features extractor
between the actor and the critic (this saves computation time)
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
clip_mean: float = 2.0,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = True,
):
super().__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
use_sde,
log_std_init,
sde_net_arch,
use_expln,
clip_mean,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
n_critics,
share_features_extractor,
)

View File

@ -0,0 +1,406 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
import numpy as np
import torch as th
from torch.nn import functional as F
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import polyak_update
from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
from ..misc.distTools import new_dist_like
from ..projections.base_projection_layer import BaseProjectionLayer
from ..projections.frob_projection_layer import FrobeniusProjectionLayer
from ..projections.w2_projection_layer import WassersteinProjectionLayer
from ..projections.kl_projection_layer import KLProjectionLayer
from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass
# CAP the standard deviation of the actor
LOG_STD_MAX = 2
LOG_STD_MIN = -20
class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
"""
Trust Region Layers (TRL) based on SAC (Soft Actor Critic)
This implementation is almost a 1:1-copy of the sb3-code for SAC.
Only minor changes have been made to implement Differential Trust Region Layers
Description from original SAC implementation:
Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor,
This implementation borrows code from original implementation (https://github.com/haarnoja/sac)
from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo
(https://github.com/rail-berkeley/softlearning/)
and from Stable Baselines (https://github.com/hill-a/stable-baselines)
Paper: https://arxiv.org/abs/1801.01290
Introduction to SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html
Note: we use double q target and not value target as discussed
in https://github.com/hill-a/stable-baselines/issues/270
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: learning rate for adam optimizer,
the same learning rate will be used for all networks (Q-Values, Actor and Value function)
it can be a function of the current progress remaining (from 1 to 0)
:param buffer_size: size of the replay buffer
:param learning_starts: how many steps of the model to collect transitions for before learning starts
:param batch_size: Minibatch size for each gradient update
:param tau: the soft update coefficient ("Polyak update", between 0 and 1)
:param gamma: the discount factor
:param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
like ``(5, "step")`` or ``(2, "episode")``.
:param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
Set to ``-1`` means to do as many gradient steps as steps done in the environment
during the rollout.
:param action_noise: the action noise type (None by default), this can help
for hard exploration problem. Cf common.noise for the different action noise type.
:param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
If ``None``, it will be automatically selected.
:param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
:param ent_coef: Entropy regularization coefficient. (Equivalent to
inverse of reward scale in the original SAC paper.) Controlling exploration/exploitation trade-off.
Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value)
:param target_update_interval: update the target network every ``target_network_update_freq``
gradient steps.
:param target_entropy: target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``)
:param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
during the warm up phase (before learning starts)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": MlpPolicy,
"CnnPolicy": CnnPolicy,
"MultiInputPolicy": MultiInputPolicy,
}
def __init__(
self,
policy: Union[str, Type[SACPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 3e-4,
buffer_size: int = 1_000_000, # 1e6
learning_starts: int = 100,
batch_size: int = 256,
tau: float = 0.005,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = 1,
gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[ReplayBuffer] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
ent_coef: Union[str, float] = "auto",
target_update_interval: int = 1,
target_entropy: Union[str, float] = "auto",
use_sde: bool = False,
sde_sample_freq: int = -1,
use_sde_at_warmup: bool = False,
tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
# Different from SAC:
projection: BaseProjectionLayer = KLProjectionLayer(),
_init_setup_model: bool = True,
):
super().__init__(
policy,
env,
learning_rate,
buffer_size,
learning_starts,
batch_size,
tau,
gamma,
train_freq,
gradient_steps,
action_noise,
replay_buffer_class=replay_buffer_class,
replay_buffer_kwargs=replay_buffer_kwargs,
policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log,
verbose=verbose,
device=device,
create_eval_env=create_eval_env,
seed=seed,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
use_sde_at_warmup=use_sde_at_warmup,
optimize_memory_usage=optimize_memory_usage,
supported_action_spaces=(gym.spaces.Box),
support_multi_env=True,
)
raise Exception('TRL_SAC is not yet implemented')
self.target_entropy = target_entropy
self.log_ent_coef = None # type: Optional[th.Tensor]
# Entropy coefficient / Entropy temperature
# Inverse of the reward scale
self.ent_coef = ent_coef
self.target_update_interval = target_update_interval
self.ent_coef_optimizer = None
# Different from SAC:
self.projection = projection
self._global_steps = 0
if _init_setup_model:
self._setup_model()
def _setup_model(self) -> None:
super()._setup_model()
self._create_aliases()
# Target entropy is used when learning the entropy coefficient
if self.target_entropy == "auto":
# automatically set target entropy if needed
self.target_entropy = - \
np.prod(self.env.action_space.shape).astype(np.float32)
else:
# Force conversion
# this will also throw an error for unexpected string
self.target_entropy = float(self.target_entropy)
# The entropy coefficient or entropy can be learned automatically
# see Automating Entropy Adjustment for Maximum Entropy RL section
# of https://arxiv.org/abs/1812.05905
if isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto"):
# Default initial value of ent_coef when learned
init_value = 1.0
if "_" in self.ent_coef:
init_value = float(self.ent_coef.split("_")[1])
assert init_value > 0.0, "The initial value of ent_coef must be greater than 0"
# Note: we optimize the log of the entropy coeff which is slightly different from the paper
# as discussed in https://github.com/rail-berkeley/softlearning/issues/37
self.log_ent_coef = th.log(
th.ones(1, device=self.device) * init_value).requires_grad_(True)
self.ent_coef_optimizer = th.optim.Adam(
[self.log_ent_coef], lr=self.lr_schedule(1))
else:
# Force conversion to float
# this will throw an error if a malformed string (different from 'auto')
# is passed
self.ent_coef_tensor = th.tensor(
float(self.ent_coef)).to(self.device)
def _create_aliases(self) -> None:
self.actor = self.policy.actor
self.critic = self.policy.critic
self.critic_target = self.policy.critic_target
def train(self, gradient_steps: int, batch_size: int = 64) -> None:
# Switch to train mode (this affects batch norm / dropout)
self.policy.set_training_mode(True)
# Update optimizers learning rate
optimizers = [self.actor.optimizer, self.critic.optimizer]
if self.ent_coef_optimizer is not None:
optimizers += [self.ent_coef_optimizer]
# Update learning rate according to lr schedule
self._update_learning_rate(optimizers)
ent_coef_losses, ent_coefs = [], []
actor_losses, critic_losses = [], []
for gradient_step in range(gradient_steps):
# Sample replay buffer
replay_data = self.replay_buffer.sample(
batch_size, env=self._vec_normalize_env)
# This is new compared to SAC.
# Calculating the TR-Projections we need to know the step number
self._global_steps += 1
# We need to sample because `log_std` may have changed between two gradient steps
if self.use_sde:
self.actor.reset_noise()
#################
# Orig Code:
# Action by the current actor for the sampled state
# actions_pi, log_prob = self.actor.action_log_prob(
# replay_data.observations)
# log_prob = log_prob.reshape(-1, 1)
act = self.actor
features = act.extract_features(replay_data.observations)
latent_pi = act.latent_pi(features)
mean_actions = act.mu(latent_pi)
# TODO: Allow contextual covariance with sde
if self.use_sde:
log_std = act.log_std
else:
# Unstructured exploration (Original implementation)
log_std = act.log_std(latent_pi)
# Original Implementation to cap the standard deviation
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
act_dist = self.action_dist
# internal A
if self.use_sde:
actions_pi = self.actions_from_params(
mean_actions, log_std, latent_pi) # latent_pi = latent_sde
else:
actions_pi = act_dist.actions_from_params(
mean_actions, log_std)
p_dist = self.action_dist.distribution
q_dist = new_dist_like(
p_dist, replay_data.means, replay_data.stds)
proj_p = self.projection(p_dist, q_dist, self._global_steps)
log_prob = proj_p.log_prob(actions_pi)
log_prob = log_prob.reshape(-1, 1)
####################
ent_coef_loss = None
if self.ent_coef_optimizer is not None:
# Important: detach the variable from the graph
# so we don't change it with other losses
# see https://github.com/rail-berkeley/softlearning/issues/60
ent_coef = th.exp(self.log_ent_coef.detach())
ent_coef_loss = - \
(self.log_ent_coef * (log_prob +
self.target_entropy).detach()).mean()
ent_coef_losses.append(ent_coef_loss.item())
else:
ent_coef = self.ent_coef_tensor
ent_coefs.append(ent_coef.item())
# Optimize entropy coefficient, also called
# entropy temperature or alpha in the paper
if ent_coef_loss is not None:
self.ent_coef_optimizer.zero_grad()
ent_coef_loss.backward()
self.ent_coef_optimizer.step()
with th.no_grad():
# Select action according to policy
next_actions, next_log_prob = self.actor.action_log_prob(
replay_data.next_observations)
# Compute the next Q values: min over all critics targets
next_q_values = th.cat(self.critic_target(
replay_data.next_observations, next_actions), dim=1)
next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
# add entropy term
next_q_values = next_q_values - \
ent_coef * next_log_prob.reshape(-1, 1)
# td error + entropy term
target_q_values = replay_data.rewards + \
(1 - replay_data.dones) * self.gamma * next_q_values
# Get current Q-values estimates for each critic network
# using action from the replay buffer
current_q_values = self.critic(
replay_data.observations, replay_data.actions)
projection_loss = th.zeros(1)
# Compute critic loss
critic_loss_raw = 0.5 * sum(F.mse_loss(current_q, target_q_values)
for current_q in current_q_values)
critic_loss = critic_loss_raw + projection_loss
critic_losses.append(critic_loss.item())
# Optimize the critic
self.critic.optimizer.zero_grad()
critic_loss.backward()
self.critic.optimizer.step()
# Compute actor loss
# Alternative: actor_loss = th.mean(log_prob - qf1_pi)
# Mean over all critic networks
q_values_pi = th.cat(self.critic(
replay_data.observations, actions_pi), dim=1)
min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
actor_losses.append(actor_loss.item())
# Optimize the actor
self.actor.optimizer.zero_grad()
actor_loss.backward()
self.actor.optimizer.step()
# Update target networks
if gradient_step % self.target_update_interval == 0:
polyak_update(self.critic.parameters(),
self.critic_target.parameters(), self.tau)
self._n_updates += gradient_steps
self.logger.record("train/n_updates",
self._n_updates, exclude="tensorboard")
self.logger.record("train/ent_coef", np.mean(ent_coefs))
self.logger.record("train/actor_loss", np.mean(actor_losses))
self.logger.record("train/critic_loss", np.mean(critic_losses))
if len(ent_coef_losses) > 0:
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
def learn(
self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,
n_eval_episodes: int = 5,
tb_log_name: str = "SAC",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> OffPolicyAlgorithm:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
eval_env=eval_env,
eval_freq=eval_freq,
n_eval_episodes=n_eval_episodes,
tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
)
def _excluded_save_params(self) -> List[str]:
return super()._excluded_save_params() + ["actor", "critic", "critic_target"]
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
if self.ent_coef_optimizer is not None:
saved_pytorch_variables = ["log_ent_coef"]
state_dicts.append("ent_coef_optimizer")
else:
saved_pytorch_variables = ["ent_coef_tensor"]
return state_dicts, saved_pytorch_variables

18
test.py
View File

@ -6,12 +6,12 @@ import os
import time import time
import datetime import datetime
from stable_baselines3 import SAC, PPO, A2C from stable_baselines3 import
from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
from metastable_baselines.trl_pg import TRL_PG from metastable_baselines.ppo import PPO
from metastable_baselines.trl_pg.policies import MlpPolicy from metastable_baselines.ppo.policies import MlpPolicy
from metastable_baselines.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer from metastable_baselines.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
import columbus import columbus
@ -22,7 +22,7 @@ root_path = '.'
def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=True, saveModel=True, n_eval_episodes=0): def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
env = gym.make(env_name) env = gym.make(env_name)
use_sde = False use_sde = False
ppo = TRL_PG( ppo = PPO(
MlpPolicy, MlpPolicy,
env, env,
verbose=0, verbose=0,
@ -37,13 +37,13 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=Tr
use_sde=use_sde, # False use_sde=use_sde, # False
clip_range=0.2, clip_range=0.2,
) )
trl_pg = TRL_PG( trl_frob = PPO(
MlpPolicy, MlpPolicy,
env, env,
projection=FrobeniusProjectionLayer(), projection=FrobeniusProjectionLayer(),
verbose=0, verbose=0,
tensorboard_log=root_path+"/logs_tb/"+env_name + tensorboard_log=root_path+"/logs_tb/"+env_name +
"/trl_pg"+(['', '_sde'][use_sde])+"/", "/trl_frob"+(['', '_sde'][use_sde])+"/",
learning_rate=3e-4, learning_rate=3e-4,
gamma=0.99, gamma=0.99,
gae_lambda=0.95, gae_lambda=0.95,
@ -54,12 +54,12 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=Tr
clip_range=2, # 0.2 clip_range=2, # 0.2
) )
print('TRL_PG:')
testModel(trl_pg, timesteps, showRes,
saveModel, n_eval_episodes)
print('PPO:') print('PPO:')
testModel(ppo, timesteps, showRes, testModel(ppo, timesteps, showRes,
saveModel, n_eval_episodes) saveModel, n_eval_episodes)
print('TRL_frob:')
testModel(trl_frob, timesteps, showRes,
saveModel, n_eval_episodes)
def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16): def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16):