Renamed TRL_PG to PPO
This commit is contained in:
parent 1706bea571
commit b1ed9fc2b8

metastable_baselines/ppo/__init__.py (new file, 2 lines added)
@@ -0,0 +1,2 @@
from ..trl_pg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from ..trl_pg.trl_pg import TRL_PG
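
The two re-exports above make the renamed package the new entry point. A minimal import sketch, not part of this commit, assuming the repository is installed as `metastable_baselines` and the `trl_pg` module it forwards to still exists:

    from metastable_baselines.ppo import CnnPolicy, MlpPolicy, MultiInputPolicy, TRL_PG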
metastable_baselines/ppo/policies.py (new file, 514 lines added)
@@ -0,0 +1,514 @@
import collections
import warnings
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import gym
import numpy as np
import torch as th
from torch import nn

import math

from stable_baselines3.common.distributions import (
    BernoulliDistribution,
    CategoricalDistribution,
    DiagGaussianDistribution,
    Distribution,
    MultiCategoricalDistribution,
    StateDependentNoiseDistribution,
)
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    CombinedExtractor,
    FlattenExtractor,
    MlpExtractor,
    NatureCNN,
)
from stable_baselines3.common.type_aliases import Schedule

from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    CombinedExtractor,
    FlattenExtractor,
    NatureCNN,
)

from ..distributions import UniversalGaussianDistribution, make_proba_distribution


class ActorCriticPolicy(BasePolicy):
    """
    Code stolen from SB3

    Policy class for actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param sde_net_arch: Network architecture for extracting features
        when using gSDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    # TODO: Allow passing of dist_kwargs into dist

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):

        if optimizer_kwargs is None:
            optimizer_kwargs = {}
            # Small values to avoid NaN in Adam optimizer
            if optimizer_class == th.optim.Adam:
                optimizer_kwargs["eps"] = 1e-5

        super().__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            squash_output=squash_output,
        )

        # Default network architecture, from stable-baselines
        if net_arch is None:
            if features_extractor_class == NatureCNN:
                net_arch = []
            else:
                net_arch = [dict(pi=[64, 64], vf=[64, 64])]

        self.net_arch = net_arch
        self.activation_fn = activation_fn
        self.ortho_init = ortho_init

        self.features_extractor = features_extractor_class(
            self.observation_space, **self.features_extractor_kwargs)
        self.features_dim = self.features_extractor.features_dim

        self.normalize_images = normalize_images
        self.log_std_init = log_std_init
        dist_kwargs = None
        # Keyword arguments for gSDE distribution
        if use_sde:
            dist_kwargs = {
                "full_std": full_std,
                "squash_output": squash_output,
                "use_expln": use_expln,
                "learn_features": False,
            }

        if sde_net_arch is not None:
            warnings.warn(
                "sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)

        self.use_sde = use_sde
        self.dist_kwargs = dist_kwargs

        # Action distribution
        self.action_dist = make_proba_distribution(
            action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)

        self._build(lr_schedule)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        data = super()._get_constructor_parameters()

        default_none_kwargs = self.dist_kwargs or collections.defaultdict(
            lambda: None)

        data.update(
            dict(
                net_arch=self.net_arch,
                activation_fn=self.activation_fn,
                use_sde=self.use_sde,
                log_std_init=self.log_std_init,
                squash_output=default_none_kwargs["squash_output"],
                full_std=default_none_kwargs["full_std"],
                use_expln=default_none_kwargs["use_expln"],
                # dummy lr schedule, not needed for loading policy alone
                lr_schedule=self._dummy_schedule,
                ortho_init=self.ortho_init,
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                features_extractor_class=self.features_extractor_class,
                features_extractor_kwargs=self.features_extractor_kwargs,
            )
        )
        return data

    def reset_noise(self, n_envs: int = 1) -> None:
        """
        Sample new weights for the exploration matrix.

        :param n_envs:
        """
        assert isinstance(
            self.action_dist, StateDependentNoiseDistribution) or isinstance(
            self.action_dist, UniversalGaussianDistribution), "reset_noise() is only available when using gSDE"
        self.action_dist.sample_weights(self.log_std, batch_size=n_envs)

    def _build_mlp_extractor(self) -> None:
        """
        Create the policy and value networks.
        Part of the layers can be shared.
        """
        # Note: If net_arch is None and some features extractor is used,
        # net_arch here is an empty list and mlp_extractor does not
        # really contain any layers (acts like an identity module).
        self.mlp_extractor = MlpExtractor(
            self.features_dim,
            net_arch=self.net_arch,
            activation_fn=self.activation_fn,
            device=self.device,
        )

    def _build(self, lr_schedule: Schedule) -> None:
        """
        Create the networks and the optimizer.

        :param lr_schedule: Learning rate schedule
            lr_schedule(1) is the initial learning rate
        """
        self._build_mlp_extractor()

        latent_dim_pi = self.mlp_extractor.latent_dim_pi

        if isinstance(self.action_dist, DiagGaussianDistribution):
            self.action_net, self.log_std = self.action_dist.proba_distribution_net(
                latent_dim=latent_dim_pi, log_std_init=self.log_std_init
            )
        elif isinstance(self.action_dist, StateDependentNoiseDistribution):
            self.action_net, self.log_std = self.action_dist.proba_distribution_net(
                latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init
            )
        elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)):
            self.action_net = self.action_dist.proba_distribution_net(
                latent_dim=latent_dim_pi)
        elif isinstance(self.action_dist, UniversalGaussianDistribution):
            self.action_net, self.chol_net = self.action_dist.proba_distribution_net(
                latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, std_init=math.exp(
                    self.log_std_init)
            )
        else:
            raise NotImplementedError(
                f"Unsupported distribution '{self.action_dist}'.")

        self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)
        # Init weights: use orthogonal initialization
        # with small initial weight for the output
        if self.ortho_init:
            # TODO: check for features_extractor
            # Values from stable-baselines.
            # features_extractor/mlp values are
            # originally from openai/baselines (default gains/init_scales).
            module_gains = {
                self.features_extractor: np.sqrt(2),
                self.mlp_extractor: np.sqrt(2),
                self.action_net: 0.01,
                self.value_net: 1,
            }
            for module, gain in module_gains.items():
                module.apply(partial(self.init_weights, gain=gain))

        # Setup optimizer with initial learning rate
        self.optimizer = self.optimizer_class(
            self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

    def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Forward pass in all the networks (actor and critic)

        :param obs: Observation
        :param deterministic: Whether to sample or use deterministic actions
        :return: action, value and log probability of the action
        """
        # Preprocess the observation if needed
        features = self.extract_features(obs)
        latent_pi, latent_vf = self.mlp_extractor(features)
        # Evaluate the values for the given observations
        values = self.value_net(latent_vf)
        distribution = self._get_action_dist_from_latent(latent_pi)
        actions = distribution.get_actions(deterministic=deterministic)
        log_prob = distribution.log_prob(actions)
        return actions, values, log_prob

    def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
        """
        Retrieve action distribution given the latent codes.

        :param latent_pi: Latent code for the actor
        :return: Action distribution
        """
        mean_actions = self.action_net(latent_pi)

        if isinstance(self.action_dist, DiagGaussianDistribution):
            return self.action_dist.proba_distribution(mean_actions, self.log_std)
        elif isinstance(self.action_dist, CategoricalDistribution):
            # Here mean_actions are the logits before the softmax
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, MultiCategoricalDistribution):
            # Here mean_actions are the flattened logits
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, BernoulliDistribution):
            # Here mean_actions are the logits (before rounding to get the binary actions)
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, StateDependentNoiseDistribution):
            return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
        elif isinstance(self.action_dist, UniversalGaussianDistribution):
            chol = self.chol_net(latent_pi)
            self.chol = chol
            return self.action_dist.proba_distribution(mean_actions, chol, latent_pi)
        else:
            raise ValueError("Invalid action distribution")

    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
        """
        Get the action according to the policy for a given observation.

        :param observation:
        :param deterministic: Whether to use stochastic or deterministic actions
        :return: Taken action according to the policy
        """
        return self.get_distribution(observation).get_actions(deterministic=deterministic)

    def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Evaluate actions according to the current policy,
        given the observations.

        :param obs:
        :param actions:
        :return: estimated value, log likelihood of taking those actions
            and entropy of the action distribution.
        """
        # Preprocess the observation if needed
        features = self.extract_features(obs)
        latent_pi, latent_vf = self.mlp_extractor(features)
        distribution = self._get_action_dist_from_latent(latent_pi)
        log_prob = distribution.log_prob(actions)
        values = self.value_net(latent_vf)
        return values, log_prob, distribution.entropy()

    def get_distribution(self, obs: th.Tensor) -> Distribution:
        """
        Get the current policy distribution given the observations.

        :param obs:
        :return: the action distribution.
        """
        features = self.extract_features(obs)
        latent_pi = self.mlp_extractor.forward_actor(features)
        return self._get_action_dist_from_latent(latent_pi)

    def predict_values(self, obs: th.Tensor) -> th.Tensor:
        """
        Get the estimated values according to the current policy given the observations.

        :param obs:
        :return: the estimated values.
        """
        features = self.extract_features(obs)
        latent_vf = self.mlp_extractor.forward_critic(features)
        return self.value_net(latent_vf)


class ActorCriticCnnPolicy(ActorCriticPolicy):
    """
    Code stolen from SB3

    CNN policy class for actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param sde_net_arch: Network architecture for extracting features
        when using gSDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            sde_net_arch,
            use_expln,
            squash_output,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )


class MultiInputActorCriticPolicy(ActorCriticPolicy):
    """
    Code stolen from SB3

    MultiInputActorClass policy class for actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    :param observation_space: Observation space (Tuple)
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param sde_net_arch: Network architecture for extracting features
        when using gSDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Uses the CombinedExtractor
    :param features_extractor_kwargs: Keyword arguments
        to pass to the feature extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: gym.spaces.Dict,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            sde_net_arch,
            use_expln,
            squash_output,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )


MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
MultiInputPolicy = MultiInputActorCriticPolicy
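
The classes above keep SB3's ActorCriticPolicy interface, so a policy can be constructed and queried on its own. A minimal sketch, not part of this commit, assuming Box observation/action spaces and that `metastable_baselines.distributions.make_proba_distribution` returns a Gaussian-family distribution for Box actions:

    import gym
    import numpy as np
    import torch as th

    from metastable_baselines.ppo.policies import MlpPolicy  # alias of ActorCriticPolicy

    obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
    act_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

    # lr_schedule maps remaining progress (1 -> 0) to a learning rate; constant here
    policy = MlpPolicy(obs_space, act_space, lr_schedule=lambda _: 3e-4)

    obs = th.as_tensor(obs_space.sample()).unsqueeze(0)   # batch of one observation
    actions, values, log_prob = policy(obs)               # forward(): sampled action, value estimate, log prob
    dist = policy.get_distribution(obs)                   # underlying action distribution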
metastable_baselines/ppo/ppo.py (new file, 391 lines added)
@@ -0,0 +1,391 @@
import warnings
from typing import Any, Dict, Optional, Type, Union, NamedTuple

import numpy as np
import torch as th
from gym import spaces
from torch.nn import functional as F

from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.utils import obs_as_tensor
from stable_baselines3.common.vec_env import VecNormalize

from ..misc.distTools import new_dist_like

from ..projections.base_projection_layer import BaseProjectionLayer
from ..projections.frob_projection_layer import FrobeniusProjectionLayer
from ..projections.w2_projection_layer import WassersteinProjectionLayer
from ..projections.kl_projection_layer import KLProjectionLayer

from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass


class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
    """
    Differentiable Trust Region Layers (TRL) for Policy Gradient (PG)

    Paper: https://arxiv.org/abs/2101.09207
    Code: This implementation borrows (/steals most) code from SB3's PPO implementation https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/ppo/ppo.py
    The implementation of the TRL-specific parts borrows from https://github.com/boschresearch/trust-region-layers/blob/main/trust_region_projections/algorithms/pg/pg.py (Stolen from Fabian's Code (Public Version))

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel)
        NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization)
        See https://github.com/pytorch/pytorch/issues/29372
    :param batch_size: Minibatch size
    :param n_epochs: Number of epochs when optimizing the surrogate loss
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
    :param clip_range: Clipping parameter, it can be a function of the current progress
        remaining (from 1 to 0).
    :param clip_range_vf: Clipping parameter for the value function,
        it can be a function of the current progress remaining (from 1 to 0).
        This is a parameter specific to the OpenAI implementation. If None is passed (default),
        no clipping will be done on the value function.
        IMPORTANT: this clipping depends on the reward scaling.
    :param normalize_advantage: Whether to normalize or not the advantage
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: The maximum value for the gradient clipping
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param target_kl: Limit the KL divergence between updates,
        because the clipping is not enough to prevent large update
        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
        By default, there is no limit on the kl div.
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param create_eval_env: Whether to create a second environment that will be
        used for evaluating the agent periodically. (Only available when passing string for the environment)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param projection: What kind of projection to use
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    policy_aliases: Dict[str, Type[BasePolicy]] = {
        "MlpPolicy": ActorCriticPolicy,
        "CnnPolicy": ActorCriticCnnPolicy,
        "MultiInputPolicy": MultiInputActorCriticPolicy,
    }

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Union[None, float, Schedule] = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        target_kl: Optional[float] = None,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",

        # Different from PPO:
        # projection: BaseProjectionLayer = KLProjectionLayer(),
        # projection: BaseProjectionLayer = WassersteinProjectionLayer(),
        # projection: BaseProjectionLayer = FrobeniusProjectionLayer(),
        projection: BaseProjectionLayer = BaseProjectionLayer(),

        _init_setup_model: bool = True,
    ):

        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            device=device,
            create_eval_env=create_eval_env,
            seed=seed,
            _init_setup_model=False,
            supported_action_spaces=(
                spaces.Box,
                # spaces.Discrete,
                # spaces.MultiDiscrete,
                # spaces.MultiBinary,
            ),
        )

        # Sanity check, otherwise it will lead to noisy gradient and NaN
        # because of the advantage normalization
        if normalize_advantage:
            assert (
                batch_size > 1
            ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"

        if self.env is not None:
            # Check that `n_steps * n_envs > 1` to avoid NaN
            # when doing advantage normalization
            buffer_size = self.env.num_envs * self.n_steps
            assert (
                buffer_size > 1
            ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
            # Check that the rollout buffer size is a multiple of the mini-batch size
            untruncated_batches = buffer_size // batch_size
            if buffer_size % batch_size > 0:
                warnings.warn(
                    f"You have specified a mini-batch size of {batch_size},"
                    f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`,"
                    f" after every {untruncated_batches} untruncated mini-batches,"
                    f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n"
                    f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n"
                    f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})"
                )
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.clip_range = clip_range
        self.clip_range_vf = clip_range_vf
        self.normalize_advantage = normalize_advantage
        self.target_kl = target_kl

        # Different from PPO:
        self.projection = projection
        self._global_steps = 0

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        super()._setup_model()

        # Initialize schedules for policy/value clipping
        self.clip_range = get_schedule_fn(self.clip_range)
        if self.clip_range_vf is not None:
            if isinstance(self.clip_range_vf, (float, int)):
                assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"

            self.clip_range_vf = get_schedule_fn(self.clip_range_vf)

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(
                self._current_progress_remaining)

        surrogate_losses = []
        entropy_losses = []
        trust_region_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True

        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                # This is new compared to PPO.
                # For calculating the TR-projections we need to know the step number
                self._global_steps += 1

                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                # Re-sample the noise matrix because the log_std has changed
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                # Different from PPO
                # TRL-Projection-Action:
                pol = self.policy
                features = pol.extract_features(rollout_data.observations)
                latent_pi, latent_vf = pol.mlp_extractor(features)
                p = pol._get_action_dist_from_latent(latent_pi)
                p_dist = p.distribution
                q_dist = new_dist_like(
                    p_dist, rollout_data.means, rollout_data.stds)
                proj_p = self.projection(p_dist, q_dist, self._global_steps)
                log_prob = proj_p.log_prob(actions).sum(dim=1)
                values = self.policy.value_net(latent_vf)
                entropy = proj_p.entropy()

                values = values.flatten()
                # Normalize advantage
                advantages = rollout_data.advantages
                if self.normalize_advantage:
                    advantages = (advantages - advantages.mean()
                                  ) / (advantages.std() + 1e-8)

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss'
                # clipped surrogate loss
                surrogate_loss_1 = advantages * ratio
                surrogate_loss_2 = advantages * \
                    th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                surrogate_loss = - \
                    th.min(surrogate_loss_1, surrogate_loss_2).mean()

                surrogate_losses.append(surrogate_loss.item())

                clip_fraction = th.mean(
                    (th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favors exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)

                entropy_losses.append(entropy_loss.item())

                # Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
                trust_region_loss = self.projection.get_trust_region_loss(
                    p, proj_p)

                trust_region_losses.append(trust_region_loss.item())

                policy_loss = surrogate_loss + self.ent_coef * entropy_loss + trust_region_loss
                pg_losses.append(policy_loss.item())

                loss = policy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean(
                        (th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(
                            f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(
                    self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            if not continue_training:
                break

        self._n_updates += self.n_epochs
        explained_var = explained_variance(
            self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        self.logger.record("train/surrogate_loss", np.mean(surrogate_losses))
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/trust_region_loss",
                           np.mean(trust_region_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        if hasattr(self.policy, "log_std"):
            self.logger.record(
                "train/std", th.exp(self.policy.log_std).mean().item())
        if hasattr(self.policy, "chol"):
            self.logger.record(
                "train/std", th.mean(th.diagonal(self.policy.chol, dim1=-2, dim2=-1)).mean().item())

        self.logger.record("train/n_updates",
                           self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "TRL_PG",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
    ) -> "TRL_PG":

        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
        )
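
A minimal training sketch for the class above, not part of this commit; it assumes a Gym environment id with a Box action space ("Pendulum-v1" here) and uses one of the projection layers already imported at the top of the file:

    import gym

    from metastable_baselines.ppo.ppo import TRL_PG
    from metastable_baselines.projections.frob_projection_layer import FrobeniusProjectionLayer

    env = gym.make("Pendulum-v1")

    # `projection` selects the trust-region projection layer applied during train()
    model = TRL_PG("MlpPolicy", env, projection=FrobeniusProjectionLayer(), verbose=1)
    model.learn(total_timesteps=10_000)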
metastable_baselines/sac/__init__.py (new file, 2 lines added)
@@ -0,0 +1,2 @@
from sb3_trl.trl_sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_trl.trl_sac.trl_sac import TRL_SAC
metastable_baselines/sac/policies.py (new file, 516 lines added)
@@ -0,0 +1,516 @@
|
|||||||
|
import warnings
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import torch as th
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||||
|
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
|
||||||
|
from stable_baselines3.common.preprocessing import get_action_dim
|
||||||
|
from stable_baselines3.common.torch_layers import (
|
||||||
|
BaseFeaturesExtractor,
|
||||||
|
CombinedExtractor,
|
||||||
|
FlattenExtractor,
|
||||||
|
NatureCNN,
|
||||||
|
create_mlp,
|
||||||
|
get_actor_critic_arch,
|
||||||
|
)
|
||||||
|
from stable_baselines3.common.type_aliases import Schedule
|
||||||
|
|
||||||
|
# CAP the standard deviation of the actor
|
||||||
|
LOG_STD_MAX = 2
|
||||||
|
LOG_STD_MIN = -20
|
||||||
|
|
||||||
|
|
||||||
|
class Actor(BasePolicy):
|
||||||
|
"""
|
||||||
|
Actor network (policy) for SAC.
|
||||||
|
|
||||||
|
:param observation_space: Obervation space
|
||||||
|
:param action_space: Action space
|
||||||
|
:param net_arch: Network architecture
|
||||||
|
:param features_extractor: Network to extract features
|
||||||
|
(a CNN when using images, a nn.Flatten() layer otherwise)
|
||||||
|
:param features_dim: Number of features
|
||||||
|
:param activation_fn: Activation function
|
||||||
|
:param use_sde: Whether to use State Dependent Exploration or not
|
||||||
|
:param log_std_init: Initial value for the log standard deviation
|
||||||
|
:param full_std: Whether to use (n_features x n_actions) parameters
|
||||||
|
for the std instead of only (n_features,) when using gSDE.
|
||||||
|
:param sde_net_arch: Network architecture for extracting features
|
||||||
|
when using gSDE. If None, the latent features from the policy will be used.
|
||||||
|
Pass an empty list to use the states as features.
|
||||||
|
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
|
||||||
|
a positive standard deviation (cf paper). It allows to keep variance
|
||||||
|
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||||
|
:param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
|
||||||
|
:param normalize_images: Whether to normalize images or not,
|
||||||
|
dividing by 255.0 (True by default)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
observation_space: gym.spaces.Space,
|
||||||
|
action_space: gym.spaces.Space,
|
||||||
|
net_arch: List[int],
|
||||||
|
features_extractor: nn.Module,
|
||||||
|
features_dim: int,
|
||||||
|
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||||
|
use_sde: bool = False,
|
||||||
|
log_std_init: float = -3,
|
||||||
|
full_std: bool = True,
|
||||||
|
sde_net_arch: Optional[List[int]] = None,
|
||||||
|
use_expln: bool = False,
|
||||||
|
clip_mean: float = 2.0,
|
||||||
|
normalize_images: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
observation_space,
|
||||||
|
action_space,
|
||||||
|
features_extractor=features_extractor,
|
||||||
|
normalize_images=normalize_images,
|
||||||
|
squash_output=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save arguments to re-create object at loading
|
||||||
|
self.use_sde = use_sde
|
||||||
|
self.sde_features_extractor = None
|
||||||
|
self.net_arch = net_arch
|
||||||
|
self.features_dim = features_dim
|
||||||
|
self.activation_fn = activation_fn
|
||||||
|
self.log_std_init = log_std_init
|
||||||
|
self.sde_net_arch = sde_net_arch
|
||||||
|
self.use_expln = use_expln
|
||||||
|
self.full_std = full_std
|
||||||
|
self.clip_mean = clip_mean
|
||||||
|
|
||||||
|
if sde_net_arch is not None:
|
||||||
|
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
|
||||||
|
|
||||||
|
action_dim = get_action_dim(self.action_space)
|
||||||
|
latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)
|
||||||
|
self.latent_pi = nn.Sequential(*latent_pi_net)
|
||||||
|
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
|
||||||
|
|
||||||
|
if self.use_sde:
|
||||||
|
self.action_dist = StateDependentNoiseDistribution(
|
||||||
|
action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
|
||||||
|
)
|
||||||
|
self.mu, self.log_std = self.action_dist.proba_distribution_net(
|
||||||
|
latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init
|
||||||
|
)
|
||||||
|
# Avoid numerical issues by limiting the mean of the Gaussian
|
||||||
|
# to be in [-clip_mean, clip_mean]
|
||||||
|
if clip_mean > 0.0:
|
||||||
|
self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
|
||||||
|
else:
|
||||||
|
self.action_dist = SquashedDiagGaussianDistribution(action_dim)
|
||||||
|
self.mu = nn.Linear(last_layer_dim, action_dim)
|
||||||
|
self.log_std = nn.Linear(last_layer_dim, action_dim)
|
||||||
|
|
||||||
|
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||||
|
data = super()._get_constructor_parameters()
|
||||||
|
|
||||||
|
data.update(
|
||||||
|
dict(
|
||||||
|
net_arch=self.net_arch,
|
||||||
|
features_dim=self.features_dim,
|
||||||
|
activation_fn=self.activation_fn,
|
||||||
|
use_sde=self.use_sde,
|
||||||
|
log_std_init=self.log_std_init,
|
||||||
|
full_std=self.full_std,
|
||||||
|
use_expln=self.use_expln,
|
||||||
|
features_extractor=self.features_extractor,
|
||||||
|
clip_mean=self.clip_mean,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def get_std(self) -> th.Tensor:
|
||||||
|
"""
|
||||||
|
Retrieve the standard deviation of the action distribution.
|
||||||
|
Only useful when using gSDE.
|
||||||
|
It corresponds to ``th.exp(log_std)`` in the normal case,
|
||||||
|
but is slightly different when using ``expln`` function
|
||||||
|
(cf StateDependentNoiseDistribution doc).
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
msg = "get_std() is only available when using gSDE"
|
||||||
|
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
|
||||||
|
return self.action_dist.get_std(self.log_std)
|
||||||
|
|
||||||
|
def reset_noise(self, batch_size: int = 1) -> None:
|
||||||
|
"""
|
||||||
|
Sample new weights for the exploration matrix, when using gSDE.
|
||||||
|
|
||||||
|
:param batch_size:
|
||||||
|
"""
|
||||||
|
msg = "reset_noise() is only available when using gSDE"
|
||||||
|
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
|
||||||
|
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
|
||||||
|
|
||||||
|
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
|
||||||
|
"""
|
||||||
|
Get the parameters for the action distribution.
|
||||||
|
|
||||||
|
:param obs:
|
||||||
|
:return:
|
||||||
|
Mean, standard deviation and optional keyword arguments.
|
||||||
|
"""
|
||||||
|
features = self.extract_features(obs)
|
||||||
|
latent_pi = self.latent_pi(features)
|
||||||
|
mean_actions = self.mu(latent_pi)
|
||||||
|
|
||||||
|
if self.use_sde:
|
||||||
|
return mean_actions, self.log_std, dict(latent_sde=latent_pi)
|
||||||
|
# Unstructured exploration (Original implementation)
|
||||||
|
log_std = self.log_std(latent_pi)
|
||||||
|
# Original Implementation to cap the standard deviation
|
||||||
|
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
|
||||||
|
return mean_actions, log_std, {}
|
||||||
|
|
||||||
|
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||||
|
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
||||||
|
# Note: the action is squashed
|
||||||
|
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
|
||||||
|
|
||||||
|
def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||||
|
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
||||||
|
# return action and associated log prob
|
||||||
|
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
|
||||||
|
|
||||||
|
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||||
|
return self(observation, deterministic)
|
||||||
|
|
||||||
|
|
||||||
|
class SACPolicy(BasePolicy):
|
||||||
|
"""
|
||||||
|
Policy class (with both actor and critic) for SAC.
|
||||||
|
|
||||||
|
:param observation_space: Observation space
|
||||||
|
:param action_space: Action space
|
||||||
|
:param lr_schedule: Learning rate schedule (could be constant)
|
||||||
|
:param net_arch: The specification of the policy and value networks.
|
||||||
|
:param activation_fn: Activation function
|
||||||
|
:param use_sde: Whether to use State Dependent Exploration or not
|
||||||
|
:param log_std_init: Initial value for the log standard deviation
|
||||||
|
:param sde_net_arch: Network architecture for extracting features
|
||||||
|
when using gSDE. If None, the latent features from the policy will be used.
|
||||||
|
Pass an empty list to use the states as features.
|
||||||
|
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
|
||||||
|
a positive standard deviation (cf paper). It allows to keep variance
|
||||||
|
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||||
|
:param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
|
||||||
|
:param features_extractor_class: Features extractor to use.
|
||||||
|
:param features_extractor_kwargs: Keyword arguments
|
||||||
|
to pass to the features extractor.
|
||||||
|
:param normalize_images: Whether to normalize images or not,
|
||||||
|
dividing by 255.0 (True by default)
|
||||||
|
:param optimizer_class: The optimizer to use,
|
||||||
|
``th.optim.Adam`` by default
|
||||||
|
:param optimizer_kwargs: Additional keyword arguments,
|
||||||
|
excluding the learning rate, to pass to the optimizer
|
||||||
|
:param n_critics: Number of critic networks to create.
|
||||||
|
:param share_features_extractor: Whether to share or not the features extractor
|
||||||
|
between the actor and the critic (this saves computation time)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
observation_space: gym.spaces.Space,
|
||||||
|
action_space: gym.spaces.Space,
|
||||||
|
lr_schedule: Schedule,
|
||||||
|
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||||
|
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||||
|
use_sde: bool = False,
|
||||||
|
log_std_init: float = -3,
|
||||||
|
sde_net_arch: Optional[List[int]] = None,
|
||||||
|
use_expln: bool = False,
|
||||||
|
clip_mean: float = 2.0,
|
||||||
|
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||||
|
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
normalize_images: bool = True,
|
||||||
|
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||||
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
n_critics: int = 2,
|
||||||
|
share_features_extractor: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
observation_space,
|
||||||
|
action_space,
|
||||||
|
features_extractor_class,
|
||||||
|
features_extractor_kwargs,
|
||||||
|
optimizer_class=optimizer_class,
|
||||||
|
optimizer_kwargs=optimizer_kwargs,
|
||||||
|
squash_output=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if net_arch is None:
|
||||||
|
if features_extractor_class == NatureCNN:
|
||||||
|
net_arch = []
|
||||||
|
else:
|
||||||
|
net_arch = [256, 256]
|
||||||
|
|
||||||
|
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
|
||||||
|
|
||||||
|
self.net_arch = net_arch
|
||||||
|
self.activation_fn = activation_fn
|
||||||
|
self.net_args = {
|
||||||
|
"observation_space": self.observation_space,
|
||||||
|
"action_space": self.action_space,
|
||||||
|
"net_arch": actor_arch,
|
||||||
|
"activation_fn": self.activation_fn,
|
||||||
|
"normalize_images": normalize_images,
|
||||||
|
}
|
||||||
|
self.actor_kwargs = self.net_args.copy()
|
||||||
|
|
||||||
|
if sde_net_arch is not None:
|
||||||
|
warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
|
||||||
|
|
||||||
|
sde_kwargs = {
|
||||||
|
"use_sde": use_sde,
|
||||||
|
"log_std_init": log_std_init,
|
||||||
|
"use_expln": use_expln,
|
||||||
|
"clip_mean": clip_mean,
|
||||||
|
}
|
||||||
|
self.actor_kwargs.update(sde_kwargs)
|
||||||
|
self.critic_kwargs = self.net_args.copy()
|
||||||
|
self.critic_kwargs.update(
|
||||||
|
{
|
||||||
|
"n_critics": n_critics,
|
||||||
|
"net_arch": critic_arch,
|
||||||
|
"share_features_extractor": share_features_extractor,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.actor, self.actor_target = None, None
|
||||||
|
self.critic, self.critic_target = None, None
|
||||||
|
self.share_features_extractor = share_features_extractor
|
||||||
|
|
||||||
|
self._build(lr_schedule)
|
||||||
|
|
||||||
|

    def _build(self, lr_schedule: Schedule) -> None:
        self.actor = self.make_actor()
        self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

        if self.share_features_extractor:
            self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
            # Do not optimize the shared features extractor with the critic loss
            # otherwise, there are gradient computation issues
            critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
        else:
            # Create a separate features extractor for the critic
            # this requires more memory and computation
            self.critic = self.make_critic(features_extractor=None)
            critic_parameters = self.critic.parameters()

        # Critic target should not share the features extractor with critic
        self.critic_target = self.make_critic(features_extractor=None)
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)

        # Target networks should always be in eval mode
        self.critic_target.set_training_mode(False)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        data = super()._get_constructor_parameters()

        data.update(
            dict(
                net_arch=self.net_arch,
                activation_fn=self.net_args["activation_fn"],
                use_sde=self.actor_kwargs["use_sde"],
                log_std_init=self.actor_kwargs["log_std_init"],
                use_expln=self.actor_kwargs["use_expln"],
                clip_mean=self.actor_kwargs["clip_mean"],
                n_critics=self.critic_kwargs["n_critics"],
                lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                features_extractor_class=self.features_extractor_class,
                features_extractor_kwargs=self.features_extractor_kwargs,
            )
        )
        return data
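    # Illustrative note (comment added here, not in the SB3 original): the dict
    # returned above is stored next to the policy's state dict when the policy is
    # saved, which is what lets a saved SACPolicy be re-instantiated later without
    # the original lr_schedule; the dummy schedule merely stands in for it.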

    def reset_noise(self, batch_size: int = 1) -> None:
        """
        Sample new weights for the exploration matrix, when using gSDE.

        :param batch_size:
        """
        self.actor.reset_noise(batch_size=batch_size)

    def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
        actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
        return Actor(**actor_kwargs).to(self.device)

    def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ContinuousCritic:
        critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
        return ContinuousCritic(**critic_kwargs).to(self.device)

    def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
        return self._predict(obs, deterministic=deterministic)

    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
        return self.actor(observation, deterministic)

    def set_training_mode(self, mode: bool) -> None:
        """
        Put the policy in either training or evaluation mode.

        This affects certain modules, such as batch normalisation and dropout.

        :param mode: if true, set to training mode, else set to evaluation mode
        """
        self.actor.set_training_mode(mode)
        self.critic.set_training_mode(mode)
        self.training = mode


MlpPolicy = SACPolicy


class CnnPolicy(SACPolicy):
    """
    Policy class (with both actor and critic) for SAC.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param sde_net_arch: Network architecture for extracting features
        when using gSDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
    :param features_extractor_class: Features extractor to use.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether to share or not the features extractor
        between the actor and the critic (this saves computation time)
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        n_critics: int = 2,
        share_features_extractor: bool = True,
    ):
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            use_sde,
            log_std_init,
            sde_net_arch,
            use_expln,
            clip_mean,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
            n_critics,
            share_features_extractor,
        )


class MultiInputPolicy(SACPolicy):
    """
    Policy class (with both actor and critic) for SAC.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param sde_net_arch: Network architecture for extracting features
        when using gSDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
    :param features_extractor_class: Features extractor to use.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether to share or not the features extractor
        between the actor and the critic (this saves computation time)
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        n_critics: int = 2,
        share_features_extractor: bool = True,
    ):
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            use_sde,
            log_std_init,
            sde_net_arch,
            use_expln,
            clip_mean,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
            n_critics,
            share_features_extractor,
        )
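
A minimal usage sketch of the aliases above (illustration only, not part of the commit; it assumes a standard Gym Box environment such as Pendulum-v1 and uses SB3's constant_fn helper to build a constant learning-rate schedule):

import gym
from stable_baselines3.common.utils import constant_fn

env = gym.make("Pendulum-v1")
policy = MlpPolicy(                     # alias for SACPolicy with a FlattenExtractor
    env.observation_space,
    env.action_space,
    lr_schedule=constant_fn(3e-4),      # _build() evaluates lr_schedule(1) for the optimizers
    net_arch=[256, 256],
)
action, _ = policy.predict(env.observation_space.sample())

In practice the aliases are rarely built by hand; the algorithm class resolves the "MlpPolicy" / "CnnPolicy" / "MultiInputPolicy" strings through its policy_aliases mapping and forwards policy_kwargs.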
406
metastable_baselines/sac/sac.py
Normal file
@@ -0,0 +1,406 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import gym
import numpy as np
import torch as th
from torch.nn import functional as F

from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import polyak_update
from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy

from ..misc.distTools import new_dist_like

from ..projections.base_projection_layer import BaseProjectionLayer
from ..projections.frob_projection_layer import FrobeniusProjectionLayer
from ..projections.w2_projection_layer import WassersteinProjectionLayer
from ..projections.kl_projection_layer import KLProjectionLayer

from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass

# CAP the standard deviation of the actor
LOG_STD_MAX = 2
LOG_STD_MIN = -20


class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
    """
    Trust Region Layers (TRL) based on SAC (Soft Actor Critic).
    This implementation is almost a 1:1 copy of the sb3 code for SAC;
    only minor changes have been made to implement Differentiable Trust Region Layers.

    Description from the original SAC implementation:
    Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.
    This implementation borrows code from the original implementation (https://github.com/haarnoja/sac),
    from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo
    (https://github.com/rail-berkeley/softlearning/)
    and from Stable Baselines (https://github.com/hill-a/stable-baselines).
    Paper: https://arxiv.org/abs/1801.01290
    Introduction to SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html

    Note: we use double q target and not value target as discussed
    in https://github.com/hill-a/stable-baselines/issues/270

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: learning rate for adam optimizer,
        the same learning rate will be used for all networks (Q-Values, Actor and Value function)
        it can be a function of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
        Set to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param action_noise: the action noise type (None by default), this can help
        for hard exploration problem. Cf common.noise for the different action noise type.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param ent_coef: Entropy regularization coefficient. (Equivalent to
        inverse of reward scale in the original SAC paper.) Controlling exploration/exploitation trade-off.
        Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value)
    :param target_update_interval: update the target network every ``target_network_update_freq``
        gradient steps.
    :param target_entropy: target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``)
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
        during the warm up phase (before learning starts)
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param create_eval_env: Whether to create a second environment that will be
        used for evaluating the agent periodically. (Only available when passing string for the environment)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param projection: Trust region projection layer to use (this argument differs from plain SAC)
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    policy_aliases: Dict[str, Type[BasePolicy]] = {
        "MlpPolicy": MlpPolicy,
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
    }

    def __init__(
        self,
        policy: Union[str, Type[SACPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        buffer_size: int = 1_000_000,  # 1e6
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 1,
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[ReplayBuffer] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        ent_coef: Union[str, float] = "auto",
        target_update_interval: int = 1,
        target_entropy: Union[str, float] = "auto",
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        use_sde_at_warmup: bool = False,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",

        # Different from SAC:
        projection: BaseProjectionLayer = KLProjectionLayer(),

        _init_setup_model: bool = True,
    ):

        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            policy_kwargs=policy_kwargs,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            create_eval_env=create_eval_env,
            seed=seed,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            use_sde_at_warmup=use_sde_at_warmup,
            optimize_memory_usage=optimize_memory_usage,
            supported_action_spaces=(gym.spaces.Box),
            support_multi_env=True,
        )

        raise Exception('TRL_SAC is not yet implemented')

        self.target_entropy = target_entropy
        self.log_ent_coef = None  # type: Optional[th.Tensor]
        # Entropy coefficient / Entropy temperature
        # Inverse of the reward scale
        self.ent_coef = ent_coef
        self.target_update_interval = target_update_interval
        self.ent_coef_optimizer = None

        # Different from SAC:
        self.projection = projection
        self._global_steps = 0

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        super()._setup_model()
        self._create_aliases()
        # Target entropy is used when learning the entropy coefficient
        if self.target_entropy == "auto":
            # automatically set target entropy if needed
            self.target_entropy = - \
                np.prod(self.env.action_space.shape).astype(np.float32)
        else:
            # Force conversion
            # this will also throw an error for unexpected string
            self.target_entropy = float(self.target_entropy)

        # The entropy coefficient or entropy can be learned automatically
        # see Automating Entropy Adjustment for Maximum Entropy RL section
        # of https://arxiv.org/abs/1812.05905
        if isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto"):
            # Default initial value of ent_coef when learned
            init_value = 1.0
            if "_" in self.ent_coef:
                init_value = float(self.ent_coef.split("_")[1])
                assert init_value > 0.0, "The initial value of ent_coef must be greater than 0"

            # Note: we optimize the log of the entropy coeff which is slightly different from the paper
            # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
            self.log_ent_coef = th.log(
                th.ones(1, device=self.device) * init_value).requires_grad_(True)
            self.ent_coef_optimizer = th.optim.Adam(
                [self.log_ent_coef], lr=self.lr_schedule(1))
        else:
            # Force conversion to float
            # this will throw an error if a malformed string (different from 'auto')
            # is passed
            self.ent_coef_tensor = th.tensor(
                float(self.ent_coef)).to(self.device)
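    # Illustrative note (comment only, not in the SB3 original): ent_coef="auto_0.1"
    # parses init_value = 0.1 above, so the quantity optimized by Adam is log(0.1);
    # a plain float such as ent_coef=0.2 skips the optimizer entirely and is stored
    # as a fixed ent_coef_tensor on the target device.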

    def _create_aliases(self) -> None:
        self.actor = self.policy.actor
        self.critic = self.policy.critic
        self.critic_target = self.policy.critic_target

    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizers learning rate
        optimizers = [self.actor.optimizer, self.critic.optimizer]
        if self.ent_coef_optimizer is not None:
            optimizers += [self.ent_coef_optimizer]

        # Update learning rate according to lr schedule
        self._update_learning_rate(optimizers)

        ent_coef_losses, ent_coefs = [], []
        actor_losses, critic_losses = [], []

        for gradient_step in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(
                batch_size, env=self._vec_normalize_env)

            # This is new compared to SAC:
            # calculating the TR-projections requires the current step number.
            self._global_steps += 1

            # We need to sample because `log_std` may have changed between two gradient steps
            if self.use_sde:
                self.actor.reset_noise()

            #################
            # Orig Code:
            # Action by the current actor for the sampled state
            # actions_pi, log_prob = self.actor.action_log_prob(
            #     replay_data.observations)
            # log_prob = log_prob.reshape(-1, 1)

            act = self.actor
            features = act.extract_features(replay_data.observations)
            latent_pi = act.latent_pi(features)
            mean_actions = act.mu(latent_pi)

            # TODO: Allow contextual covariance with sde
            if self.use_sde:
                log_std = act.log_std
            else:
                # Unstructured exploration (Original implementation)
                log_std = act.log_std(latent_pi)
                # Original Implementation to cap the standard deviation
                log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)

            act_dist = self.action_dist
            # internal A
            if self.use_sde:
                actions_pi = self.actions_from_params(
                    mean_actions, log_std, latent_pi)  # latent_pi = latent_sde
            else:
                actions_pi = act_dist.actions_from_params(
                    mean_actions, log_std)

            p_dist = self.action_dist.distribution
            q_dist = new_dist_like(
                p_dist, replay_data.means, replay_data.stds)
            proj_p = self.projection(p_dist, q_dist, self._global_steps)
            log_prob = proj_p.log_prob(actions_pi)
            log_prob = log_prob.reshape(-1, 1)

            ####################
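            # Note (explanatory comment, not in the original SB3 code): p_dist is the
            # current policy's Gaussian, q_dist is rebuilt from the per-sample means
            # and stds stored in the replay buffer (presumably by the
            # GaussianRolloutCollectorAuxclass mixin), and proj_p is the trust-region
            # projection of p_dist against q_dist; its log_prob replaces the log
            # probability that plain SAC would take directly from the actor.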

            ent_coef_loss = None
            if self.ent_coef_optimizer is not None:
                # Important: detach the variable from the graph
                # so we don't change it with other losses
                # see https://github.com/rail-berkeley/softlearning/issues/60
                ent_coef = th.exp(self.log_ent_coef.detach())
                ent_coef_loss = - \
                    (self.log_ent_coef * (log_prob +
                     self.target_entropy).detach()).mean()
                ent_coef_losses.append(ent_coef_loss.item())
            else:
                ent_coef = self.ent_coef_tensor

            ent_coefs.append(ent_coef.item())

            # Optimize entropy coefficient, also called
            # entropy temperature or alpha in the paper
            if ent_coef_loss is not None:
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()

            with th.no_grad():
                # Select action according to policy
                next_actions, next_log_prob = self.actor.action_log_prob(
                    replay_data.next_observations)
                # Compute the next Q values: min over all critics targets
                next_q_values = th.cat(self.critic_target(
                    replay_data.next_observations, next_actions), dim=1)
                next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
                # add entropy term
                next_q_values = next_q_values - \
                    ent_coef * next_log_prob.reshape(-1, 1)
                # td error + entropy term
                target_q_values = replay_data.rewards + \
                    (1 - replay_data.dones) * self.gamma * next_q_values

            # Get current Q-values estimates for each critic network
            # using action from the replay buffer
            current_q_values = self.critic(
                replay_data.observations, replay_data.actions)

            projection_loss = th.zeros(1)

            # Compute critic loss
            critic_loss_raw = 0.5 * sum(F.mse_loss(current_q, target_q_values)
                                        for current_q in current_q_values)
            critic_loss = critic_loss_raw + projection_loss
            critic_losses.append(critic_loss.item())

            # Optimize the critic
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()

            # Compute actor loss
            # Alternative: actor_loss = th.mean(log_prob - qf1_pi)
            # Mean over all critic networks
            q_values_pi = th.cat(self.critic(
                replay_data.observations, actions_pi), dim=1)
            min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
            actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
            actor_losses.append(actor_loss.item())

            # Optimize the actor
            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()

            # Update target networks
            if gradient_step % self.target_update_interval == 0:
                polyak_update(self.critic.parameters(),
                              self.critic_target.parameters(), self.tau)

        self._n_updates += gradient_steps

        self.logger.record("train/n_updates",
                           self._n_updates, exclude="tensorboard")
        self.logger.record("train/ent_coef", np.mean(ent_coefs))
        self.logger.record("train/actor_loss", np.mean(actor_losses))
        self.logger.record("train/critic_loss", np.mean(critic_losses))
        if len(ent_coef_losses) > 0:
            self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "SAC",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
    ) -> OffPolicyAlgorithm:

        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
        )

    def _excluded_save_params(self) -> List[str]:
        return super()._excluded_save_params() + ["actor", "critic", "critic_target"]

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
        if self.ent_coef_optimizer is not None:
            saved_pytorch_variables = ["log_ent_coef"]
            state_dicts.append("ent_coef_optimizer")
        else:
            saved_pytorch_variables = ["ent_coef_tensor"]
        return state_dicts, saved_pytorch_variables
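
TRL_SAC still raises at construction time, so it cannot be trained yet. For orientation only, the sketch below shows where the trust-region choice plugs in (it uses the projection classes already imported above and mirrors how test.py passes them to the policy-gradient variant):

from metastable_baselines.projections import FrobeniusProjectionLayer, KLProjectionLayer

default_projection = KLProjectionLayer()          # the constructor's default `projection` argument
alternative_projection = FrobeniusProjectionLayer()  # any other BaseProjectionLayer can be swapped in
# Once the not-yet-implemented guard is removed, the layer would be passed as
# TRL_SAC(..., projection=alternative_projection), analogous to the PPO usage in test.py.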
18
test.py
@@ -6,12 +6,12 @@ import os
 import time
 import datetime

-from stable_baselines3 import SAC, PPO, A2C
+from stable_baselines3 import
 from stable_baselines3.common.evaluation import evaluate_policy
 from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy

-from metastable_baselines.trl_pg import TRL_PG
-from metastable_baselines.trl_pg.policies import MlpPolicy
+from metastable_baselines.ppo import PPO
+from metastable_baselines.ppo.policies import MlpPolicy
 from metastable_baselines.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
 import columbus

@@ -22,7 +22,7 @@ root_path = '.'
 def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
     env = gym.make(env_name)
     use_sde = False
-    ppo = TRL_PG(
+    ppo = PPO(
         MlpPolicy,
         env,
         verbose=0,
@@ -37,13 +37,13 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=Tr
         use_sde=use_sde, # False
         clip_range=0.2,
     )
-    trl_pg = TRL_PG(
+    trl_frob = PPO(
         MlpPolicy,
         env,
         projection=FrobeniusProjectionLayer(),
         verbose=0,
         tensorboard_log=root_path+"/logs_tb/"+env_name +
-        "/trl_pg"+(['', '_sde'][use_sde])+"/",
+        "/trl_frob"+(['', '_sde'][use_sde])+"/",
         learning_rate=3e-4,
         gamma=0.99,
         gae_lambda=0.95,
@@ -54,12 +54,12 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=Tr
         clip_range=2, # 0.2
     )

-    print('TRL_PG:')
-    testModel(trl_pg, timesteps, showRes,
-              saveModel, n_eval_episodes)
     print('PPO:')
     testModel(ppo, timesteps, showRes,
               saveModel, n_eval_episodes)
+    print('TRL_frob:')
+    testModel(trl_frob, timesteps, showRes,
+              saveModel, n_eval_episodes)


 def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16):
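
A condensed usage sketch assembled from the updated test.py above (illustration only, not part of the commit; the Columbus environment comes from the external columbus package, which test.py imports before creating the env, presumably to register it):

import gym
import columbus  # noqa: F401 -- imported before gym.make, as in test.py

from metastable_baselines.ppo import PPO
from metastable_baselines.ppo.policies import MlpPolicy
from metastable_baselines.projections import FrobeniusProjectionLayer

env = gym.make('ColumbusCandyland_Aux10-v0')
trl_frob = PPO(MlpPolicy, env, projection=FrobeniusProjectionLayer(), verbose=0)
trl_frob.learn(total_timesteps=10_000)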