diff --git a/sb3_trl/__init__.py b/sb3_trl/__init__.py new file mode 100644 index 0000000..1332cb3 --- /dev/null +++ b/sb3_trl/__init__.py @@ -0,0 +1 @@ +# TODO: License diff --git a/sb3_trl/trl_pg/__init__.py b/sb3_trl/trl_pg/__init__.py new file mode 100644 index 0000000..66cf37e --- /dev/null +++ b/sb3_trl/trl_pg/__init__.py @@ -0,0 +1,2 @@ +from sb3_trl.trl_pg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy +from sb3_trl.trl_pg.trl_pg import TRL_PG diff --git a/sb3_trl/trl_pg/policies.py b/sb3_trl/trl_pg/policies.py new file mode 100644 index 0000000..8b784eb --- /dev/null +++ b/sb3_trl/trl_pg/policies.py @@ -0,0 +1,7 @@ +# This file is here just to define MlpPolicy/CnnPolicy +# that work for TRL_PG +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy + +MlpPolicy = ActorCriticPolicy +CnnPolicy = ActorCriticCnnPolicy +MultiInputPolicy = MultiInputActorCriticPolicy diff --git a/sb3_trl/trl_pg/trl_pg.py b/sb3_trl/trl_pg/trl_pg.py new file mode 100644 index 0000000..9f524e6 --- /dev/null +++ b/sb3_trl/trl_pg/trl_pg.py @@ -0,0 +1,340 @@ +import warnings +from typing import Any, Dict, Optional, Type, Union + +import numpy as np +import torch as th +from gym import spaces +from torch.nn import functional as F + +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import explained_variance, get_schedule_fn + + +class TRL_PG(OnPolicyAlgorithm): + """ + Differential Trust Region Layer (TRL) for Policy Gradient (PG) + + Paper: https://arxiv.org/abs/2101.09207 + Code: This implementation borrows (/steals most) code from SB3's PPO implementation https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/ppo/ppo.py + The implementation of the TRL-specific parts borrows from https://github.com/boschresearch/trust-region-layers/blob/main/trust_region_projections/algorithms/pg/pg.py + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function + of the current progress remaining (from 1 to 0) + :param n_steps: The number of steps to run for each environment per update + (i.e. rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel) + NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization) + See https://github.com/pytorch/pytorch/issues/29372 + :param batch_size: Minibatch size + :param n_epochs: Number of epoch when optimizing the surrogate loss + :param gamma: Discount factor + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + :param clip_range: Clipping parameter, it can be a function of the current progress + remaining (from 1 to 0). + :param clip_range_vf: Clipping parameter for the value function, + it can be a function of the current progress remaining (from 1 to 0). + This is a parameter specific to the OpenAI implementation. If None is passed (default), + no clipping will be done on the value function. + IMPORTANT: this clipping depends on the reward scaling. + :param normalize_advantage: Whether to normalize or not the advantage + :param ent_coef: Entropy coefficient for the loss calculation + :param vf_coef: Value function coefficient for the loss calculation + :param max_grad_norm: The maximum value for the gradient clipping + :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) + instead of action noise exploration (default: False) + :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE + Default: -1 (only sample at the beginning of the rollout) + :param target_kl: Limit the KL divergence between updates, + because the clipping is not enough to prevent large update + see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) + By default, there is no limit on the kl div. + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param create_eval_env: Whether to create a second environment that will be + used for evaluating the agent periodically. (Only available when passing string for the environment) + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. + :param _init_setup_model: Whether or not to build the network at the creation of the instance + """ + #TODO: Add new params to doc + + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": ActorCriticPolicy, + "CnnPolicy": ActorCriticCnnPolicy, + "MultiInputPolicy": MultiInputActorCriticPolicy, + } + + def __init__( + self, + policy: Union[str, Type[ActorCriticPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 3e-4, + n_steps: int = 2048, + batch_size: int = 64, + n_epochs: int = 10, + gamma: float = 0.99, + gae_lambda: float = 0.95, + clip_range: Union[float, Schedule] = 0.2, + clip_range_vf: Union[None, float, Schedule] = None, + normalize_advantage: bool = True, + ent_coef: float = 0.0, + vf_coef: float = 0.5, + max_grad_norm: float = 0.5, + use_sde: bool = False, + sde_sample_freq: int = -1, + target_kl: Optional[float] = None, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + + # Different from PPO: + #projection: BaseProjectionLayer = None, + projection = None, + + _init_setup_model: bool = True, + ): + + super().__init__( + policy, + env, + learning_rate=learning_rate, + n_steps=n_steps, + gamma=gamma, + gae_lambda=gae_lambda, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + tensorboard_log=tensorboard_log, + policy_kwargs=policy_kwargs, + verbose=verbose, + device=device, + create_eval_env=create_eval_env, + seed=seed, + _init_setup_model=False, + supported_action_spaces=( + spaces.Box, + spaces.Discrete, + spaces.MultiDiscrete, + spaces.MultiBinary, + ), + ) + + # Sanity check, otherwise it will lead to noisy gradient and NaN + # because of the advantage normalization + if normalize_advantage: + assert ( + batch_size > 1 + ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440" + + if self.env is not None: + # Check that `n_steps * n_envs > 1` to avoid NaN + # when doing advantage normalization + buffer_size = self.env.num_envs * self.n_steps + assert ( + buffer_size > 1 + ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}" + # Check that the rollout buffer size is a multiple of the mini-batch size + untruncated_batches = buffer_size // batch_size + if buffer_size % batch_size > 0: + warnings.warn( + f"You have specified a mini-batch size of {batch_size}," + f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`," + f" after every {untruncated_batches} untruncated mini-batches," + f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n" + f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n" + f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})" + ) + self.batch_size = batch_size + self.n_epochs = n_epochs + self.clip_range = clip_range + self.clip_range_vf = clip_range_vf + self.normalize_advantage = normalize_advantage + self.target_kl = target_kl + + # Different from PPO: + self.projection = projection + + if _init_setup_model: + self._setup_model() + + def _setup_model(self) -> None: + super()._setup_model() + + # Initialize schedules for policy/value clipping + self.clip_range = get_schedule_fn(self.clip_range) + if self.clip_range_vf is not None: + if isinstance(self.clip_range_vf, (float, int)): + assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping" + + self.clip_range_vf = get_schedule_fn(self.clip_range_vf) + + def train(self) -> None: + """ + Update policy using the currently gathered rollout buffer. + """ + # Switch to train mode (this affects batch norm / dropout) + self.policy.set_training_mode(True) + # Update optimizer learning rate + self._update_learning_rate(self.policy.optimizer) + # Compute current clip range + clip_range = self.clip_range(self._current_progress_remaining) + # Optional: clip range for the value function + if self.clip_range_vf is not None: + clip_range_vf = self.clip_range_vf(self._current_progress_remaining) + + surrogate_losses = [] + entropy_losses = [] + trust_region_losses = [] + pg_losses, value_losses = [], [] + clip_fractions = [] + + continue_training = True + + # train for n_epochs epochs + for epoch in range(self.n_epochs): + approx_kl_divs = [] + # Do a complete pass on the rollout buffer + for rollout_data in self.rollout_buffer.get(self.batch_size): + actions = rollout_data.actions + if isinstance(self.action_space, spaces.Discrete): + # Convert discrete action from float to long + actions = rollout_data.actions.long().flatten() + + # Re-sample the noise matrix because the log_std has changed + if self.use_sde: + self.policy.reset_noise(self.batch_size) + + values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions) + values = values.flatten() + # Normalize advantage + advantages = rollout_data.advantages + if self.normalize_advantage: + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + + # ratio between old and new policy, should be one at the first iteration + ratio = th.exp(log_prob - rollout_data.old_log_prob) + + # Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss' + # clipped surrogate loss + surrogate_loss_1 = advantages * ratio + surrogate_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) + surrogate_loss = -th.min(surrogate_loss_1, surrogate_loss_2).mean() + + surrogate_losses.append(surrogate_loss.item()) + + clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() + clip_fractions.append(clip_fraction) + + if self.clip_range_vf is None: + # No clipping + values_pred = values + else: + # Clip the different between old and new value + # NOTE: this depends on the reward scaling + values_pred = rollout_data.old_values + th.clamp( + values - rollout_data.old_values, -clip_range_vf, clip_range_vf + ) + # Value loss using the TD(gae_lambda) target + value_loss = F.mse_loss(rollout_data.returns, values_pred) + value_losses.append(value_loss.item()) + + # Entropy loss favor exploration + if entropy is None: + # Approximate entropy when no analytical form + entropy_loss = -th.mean(-log_prob) + else: + entropy_loss = -th.mean(entropy) + + entropy_losses.append(entropy_loss.item()) + + # Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss + #trust_region_loss = self.projection.get_trust_region_loss()#TODO: params + trust_region_loss = th.zeros(1, device=entropy_loss.device) # TODO: Implement + + trust_region_losses.append(trust_region_loss.item()) + + policy_loss = surrogate_loss + self.ent_coef * entropy_loss + trust_region_loss + pg_losses.append(policy_loss.item()) + + loss = policy_loss + self.vf_coef * value_loss + + # Calculate approximate form of reverse KL Divergence for early stopping + # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 + # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 + # and Schulman blog: http://joschu.net/blog/kl-approx.html + with th.no_grad(): + log_ratio = log_prob - rollout_data.old_log_prob + approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy() + approx_kl_divs.append(approx_kl_div) + + if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: + continue_training = False + if self.verbose >= 1: + print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}") + break + + # Optimization step + self.policy.optimizer.zero_grad() + loss.backward() + # Clip grad norm + th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.policy.optimizer.step() + + if not continue_training: + break + + self._n_updates += self.n_epochs + explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) + + # Logs + self.logger.record("train/surrogate_loss", np.mean(surrogate_losses)) + self.logger.record("train/entropy_loss", np.mean(entropy_losses)) + self.logger.record("train/trust_region_loss", np.mean(trust_region_losses)) + self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) + self.logger.record("train/value_loss", np.mean(value_losses)) + self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) + self.logger.record("train/clip_fraction", np.mean(clip_fractions)) + self.logger.record("train/loss", loss.item()) + self.logger.record("train/explained_variance", explained_var) + if hasattr(self.policy, "log_std"): + self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) + + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/clip_range", clip_range) + if self.clip_range_vf is not None: + self.logger.record("train/clip_range_vf", clip_range_vf) + + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 1, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "PPO", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> "TRL_PG": + + return super().learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + eval_env=eval_env, + eval_freq=eval_freq, + n_eval_episodes=n_eval_episodes, + tb_log_name=tb_log_name, + eval_log_path=eval_log_path, + reset_num_timesteps=reset_num_timesteps, + ) diff --git a/sb3_trl/trl_sac/__init__.py b/sb3_trl/trl_sac/__init__.py new file mode 100644 index 0000000..c0e01b7 --- /dev/null +++ b/sb3_trl/trl_sac/__init__.py @@ -0,0 +1,2 @@ +from sb3_trl.trl_sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy +from sb3_trl.trl_sac.trl_sac import TRL_SAC diff --git a/sb3_trl/trl_sac/policies.py b/sb3_trl/trl_sac/policies.py new file mode 100644 index 0000000..6fcbea1 --- /dev/null +++ b/sb3_trl/trl_sac/policies.py @@ -0,0 +1,516 @@ +import warnings +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import gym +import torch as th +from torch import nn + +from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution +from stable_baselines3.common.policies import BasePolicy, ContinuousCritic +from stable_baselines3.common.preprocessing import get_action_dim +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + NatureCNN, + create_mlp, + get_actor_critic_arch, +) +from stable_baselines3.common.type_aliases import Schedule + +# CAP the standard deviation of the actor +LOG_STD_MAX = 2 +LOG_STD_MIN = -20 + + +class Actor(BasePolicy): + """ + Actor network (policy) for SAC. + + :param observation_space: Obervation space + :param action_space: Action space + :param net_arch: Network architecture + :param features_extractor: Network to extract features + (a CNN when using images, a nn.Flatten() layer otherwise) + :param features_dim: Number of features + :param activation_fn: Activation function + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE. + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + net_arch: List[int], + features_extractor: nn.Module, + features_dim: int, + activation_fn: Type[nn.Module] = nn.ReLU, + use_sde: bool = False, + log_std_init: float = -3, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + clip_mean: float = 2.0, + normalize_images: bool = True, + ): + super().__init__( + observation_space, + action_space, + features_extractor=features_extractor, + normalize_images=normalize_images, + squash_output=True, + ) + + # Save arguments to re-create object at loading + self.use_sde = use_sde + self.sde_features_extractor = None + self.net_arch = net_arch + self.features_dim = features_dim + self.activation_fn = activation_fn + self.log_std_init = log_std_init + self.sde_net_arch = sde_net_arch + self.use_expln = use_expln + self.full_std = full_std + self.clip_mean = clip_mean + + if sde_net_arch is not None: + warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning) + + action_dim = get_action_dim(self.action_space) + latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn) + self.latent_pi = nn.Sequential(*latent_pi_net) + last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim + + if self.use_sde: + self.action_dist = StateDependentNoiseDistribution( + action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True + ) + self.mu, self.log_std = self.action_dist.proba_distribution_net( + latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init + ) + # Avoid numerical issues by limiting the mean of the Gaussian + # to be in [-clip_mean, clip_mean] + if clip_mean > 0.0: + self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean)) + else: + self.action_dist = SquashedDiagGaussianDistribution(action_dim) + self.mu = nn.Linear(last_layer_dim, action_dim) + self.log_std = nn.Linear(last_layer_dim, action_dim) + + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() + + data.update( + dict( + net_arch=self.net_arch, + features_dim=self.features_dim, + activation_fn=self.activation_fn, + use_sde=self.use_sde, + log_std_init=self.log_std_init, + full_std=self.full_std, + use_expln=self.use_expln, + features_extractor=self.features_extractor, + clip_mean=self.clip_mean, + ) + ) + return data + + def get_std(self) -> th.Tensor: + """ + Retrieve the standard deviation of the action distribution. + Only useful when using gSDE. + It corresponds to ``th.exp(log_std)`` in the normal case, + but is slightly different when using ``expln`` function + (cf StateDependentNoiseDistribution doc). + + :return: + """ + msg = "get_std() is only available when using gSDE" + assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg + return self.action_dist.get_std(self.log_std) + + def reset_noise(self, batch_size: int = 1) -> None: + """ + Sample new weights for the exploration matrix, when using gSDE. + + :param batch_size: + """ + msg = "reset_noise() is only available when using gSDE" + assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg + self.action_dist.sample_weights(self.log_std, batch_size=batch_size) + + def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]: + """ + Get the parameters for the action distribution. + + :param obs: + :return: + Mean, standard deviation and optional keyword arguments. + """ + features = self.extract_features(obs) + latent_pi = self.latent_pi(features) + mean_actions = self.mu(latent_pi) + + if self.use_sde: + return mean_actions, self.log_std, dict(latent_sde=latent_pi) + # Unstructured exploration (Original implementation) + log_std = self.log_std(latent_pi) + # Original Implementation to cap the standard deviation + log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) + return mean_actions, log_std, {} + + def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: + mean_actions, log_std, kwargs = self.get_action_dist_params(obs) + # Note: the action is squashed + return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs) + + def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + mean_actions, log_std, kwargs = self.get_action_dist_params(obs) + # return action and associated log prob + return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs) + + def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + return self(observation, deterministic) + + +class SACPolicy(BasePolicy): + """ + Policy class (with both actor and critic) for SAC. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param n_critics: Number of critic networks to create. + :param share_features_extractor: Whether to share or not the features extractor + between the actor and the critic (this saves computation time) + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + use_sde: bool = False, + log_std_init: float = -3, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + clip_mean: float = 2.0, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2, + share_features_extractor: bool = True, + ): + super().__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=True, + ) + + if net_arch is None: + if features_extractor_class == NatureCNN: + net_arch = [] + else: + net_arch = [256, 256] + + actor_arch, critic_arch = get_actor_critic_arch(net_arch) + + self.net_arch = net_arch + self.activation_fn = activation_fn + self.net_args = { + "observation_space": self.observation_space, + "action_space": self.action_space, + "net_arch": actor_arch, + "activation_fn": self.activation_fn, + "normalize_images": normalize_images, + } + self.actor_kwargs = self.net_args.copy() + + if sde_net_arch is not None: + warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning) + + sde_kwargs = { + "use_sde": use_sde, + "log_std_init": log_std_init, + "use_expln": use_expln, + "clip_mean": clip_mean, + } + self.actor_kwargs.update(sde_kwargs) + self.critic_kwargs = self.net_args.copy() + self.critic_kwargs.update( + { + "n_critics": n_critics, + "net_arch": critic_arch, + "share_features_extractor": share_features_extractor, + } + ) + + self.actor, self.actor_target = None, None + self.critic, self.critic_target = None, None + self.share_features_extractor = share_features_extractor + + self._build(lr_schedule) + + def _build(self, lr_schedule: Schedule) -> None: + self.actor = self.make_actor() + self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + + if self.share_features_extractor: + self.critic = self.make_critic(features_extractor=self.actor.features_extractor) + # Do not optimize the shared features extractor with the critic loss + # otherwise, there are gradient computation issues + critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name] + else: + # Create a separate features extractor for the critic + # this requires more memory and computation + self.critic = self.make_critic(features_extractor=None) + critic_parameters = self.critic.parameters() + + # Critic target should not share the features extractor with critic + self.critic_target = self.make_critic(features_extractor=None) + self.critic_target.load_state_dict(self.critic.state_dict()) + + self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs) + + # Target networks should always be in eval mode + self.critic_target.set_training_mode(False) + + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() + + data.update( + dict( + net_arch=self.net_arch, + activation_fn=self.net_args["activation_fn"], + use_sde=self.actor_kwargs["use_sde"], + log_std_init=self.actor_kwargs["log_std_init"], + use_expln=self.actor_kwargs["use_expln"], + clip_mean=self.actor_kwargs["clip_mean"], + n_critics=self.critic_kwargs["n_critics"], + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs, + ) + ) + return data + + def reset_noise(self, batch_size: int = 1) -> None: + """ + Sample new weights for the exploration matrix, when using gSDE. + + :param batch_size: + """ + self.actor.reset_noise(batch_size=batch_size) + + def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor: + actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor) + return Actor(**actor_kwargs).to(self.device) + + def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ContinuousCritic: + critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor) + return ContinuousCritic(**critic_kwargs).to(self.device) + + def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: + return self._predict(obs, deterministic=deterministic) + + def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + return self.actor(observation, deterministic) + + def set_training_mode(self, mode: bool) -> None: + """ + Put the policy in either training or evaluation mode. + + This affects certain modules, such as batch normalisation and dropout. + + :param mode: if true, set to training mode, else set to evaluation mode + """ + self.actor.set_training_mode(mode) + self.critic.set_training_mode(mode) + self.training = mode + + +MlpPolicy = SACPolicy + + +class CnnPolicy(SACPolicy): + """ + Policy class (with both actor and critic) for SAC. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability. + :param features_extractor_class: Features extractor to use. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param n_critics: Number of critic networks to create. + :param share_features_extractor: Whether to share or not the features extractor + between the actor and the critic (this saves computation time) + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + use_sde: bool = False, + log_std_init: float = -3, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + clip_mean: float = 2.0, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2, + share_features_extractor: bool = True, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + use_sde, + log_std_init, + sde_net_arch, + use_expln, + clip_mean, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + n_critics, + share_features_extractor, + ) + + +class MultiInputPolicy(SACPolicy): + """ + Policy class (with both actor and critic) for SAC. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability. + :param features_extractor_class: Features extractor to use. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param n_critics: Number of critic networks to create. + :param share_features_extractor: Whether to share or not the features extractor + between the actor and the critic (this saves computation time) + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + use_sde: bool = False, + log_std_init: float = -3, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + clip_mean: float = 2.0, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2, + share_features_extractor: bool = True, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + use_sde, + log_std_init, + sde_net_arch, + use_expln, + clip_mean, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + n_critics, + share_features_extractor, + ) diff --git a/sb3_trl/trl_sac/trl_sac.py b/sb3_trl/trl_sac/trl_sac.py new file mode 100644 index 0000000..2e884b9 --- /dev/null +++ b/sb3_trl/trl_sac/trl_sac.py @@ -0,0 +1,324 @@ +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import gym +import numpy as np +import torch as th +from torch.nn import functional as F + +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.noise import ActionNoise +from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import polyak_update +from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy + + +class TRL_SAC(OffPolicyAlgorithm): + """ + Trust Region Layers (TRL) based on SAC (Soft Actor Critic) + This implementation is almost a 1:1-copy of the sb3-code for SAC. + Only minor changes have been made to implement Differential Trust Region Layers + + Description from original SAC implementation: + Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor, + This implementation borrows code from original implementation (https://github.com/haarnoja/sac) + from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo + (https://github.com/rail-berkeley/softlearning/) + and from Stable Baselines (https://github.com/hill-a/stable-baselines) + Paper: https://arxiv.org/abs/1801.01290 + Introduction to SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html + + Note: we use double q target and not value target as discussed + in https://github.com/hill-a/stable-baselines/issues/270 + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: learning rate for adam optimizer, + the same learning rate will be used for all networks (Q-Values, Actor and Value function) + it can be a function of the current progress remaining (from 1 to 0) + :param buffer_size: size of the replay buffer + :param learning_starts: how many steps of the model to collect transitions for before learning starts + :param batch_size: Minibatch size for each gradient update + :param tau: the soft update coefficient ("Polyak update", between 0 and 1) + :param gamma: the discount factor + :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit + like ``(5, "step")`` or ``(2, "episode")``. + :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``) + Set to ``-1`` means to do as many gradient steps as steps done in the environment + during the rollout. + :param action_noise: the action noise type (None by default), this can help + for hard exploration problem. Cf common.noise for the different action noise type. + :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). + If ``None``, it will be automatically selected. + :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. + :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer + at a cost of more complexity. + See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 + :param ent_coef: Entropy regularization coefficient. (Equivalent to + inverse of reward scale in the original SAC paper.) Controlling exploration/exploitation trade-off. + Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value) + :param target_update_interval: update the target network every ``target_network_update_freq`` + gradient steps. + :param target_entropy: target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``) + :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) + instead of action noise exploration (default: False) + :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE + Default: -1 (only sample at the beginning of the rollout) + :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling + during the warm up phase (before learning starts) + :param create_eval_env: Whether to create a second environment that will be + used for evaluating the agent periodically. (Only available when passing string for the environment) + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. + :param _init_setup_model: Whether or not to build the network at the creation of the instance + """ + + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": MlpPolicy, + "CnnPolicy": CnnPolicy, + "MultiInputPolicy": MultiInputPolicy, + } + + def __init__( + self, + policy: Union[str, Type[SACPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 3e-4, + buffer_size: int = 1_000_000, # 1e6 + learning_starts: int = 100, + batch_size: int = 256, + tau: float = 0.005, + gamma: float = 0.99, + train_freq: Union[int, Tuple[int, str]] = 1, + gradient_steps: int = 1, + action_noise: Optional[ActionNoise] = None, + replay_buffer_class: Optional[ReplayBuffer] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + optimize_memory_usage: bool = False, + ent_coef: Union[str, float] = "auto", + target_update_interval: int = 1, + target_entropy: Union[str, float] = "auto", + use_sde: bool = False, + sde_sample_freq: int = -1, + use_sde_at_warmup: bool = False, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + + super().__init__( + policy, + env, + learning_rate, + buffer_size, + learning_starts, + batch_size, + tau, + gamma, + train_freq, + gradient_steps, + action_noise, + replay_buffer_class=replay_buffer_class, + replay_buffer_kwargs=replay_buffer_kwargs, + policy_kwargs=policy_kwargs, + tensorboard_log=tensorboard_log, + verbose=verbose, + device=device, + create_eval_env=create_eval_env, + seed=seed, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + use_sde_at_warmup=use_sde_at_warmup, + optimize_memory_usage=optimize_memory_usage, + supported_action_spaces=(gym.spaces.Box), + support_multi_env=True, + ) + + self.target_entropy = target_entropy + self.log_ent_coef = None # type: Optional[th.Tensor] + # Entropy coefficient / Entropy temperature + # Inverse of the reward scale + self.ent_coef = ent_coef + self.target_update_interval = target_update_interval + self.ent_coef_optimizer = None + + if _init_setup_model: + self._setup_model() + + def _setup_model(self) -> None: + super()._setup_model() + self._create_aliases() + # Target entropy is used when learning the entropy coefficient + if self.target_entropy == "auto": + # automatically set target entropy if needed + self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) + else: + # Force conversion + # this will also throw an error for unexpected string + self.target_entropy = float(self.target_entropy) + + # The entropy coefficient or entropy can be learned automatically + # see Automating Entropy Adjustment for Maximum Entropy RL section + # of https://arxiv.org/abs/1812.05905 + if isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto"): + # Default initial value of ent_coef when learned + init_value = 1.0 + if "_" in self.ent_coef: + init_value = float(self.ent_coef.split("_")[1]) + assert init_value > 0.0, "The initial value of ent_coef must be greater than 0" + + # Note: we optimize the log of the entropy coeff which is slightly different from the paper + # as discussed in https://github.com/rail-berkeley/softlearning/issues/37 + self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True) + self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1)) + else: + # Force conversion to float + # this will throw an error if a malformed string (different from 'auto') + # is passed + self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device) + + def _create_aliases(self) -> None: + self.actor = self.policy.actor + self.critic = self.policy.critic + self.critic_target = self.policy.critic_target + + def train(self, gradient_steps: int, batch_size: int = 64) -> None: + # Switch to train mode (this affects batch norm / dropout) + self.policy.set_training_mode(True) + # Update optimizers learning rate + optimizers = [self.actor.optimizer, self.critic.optimizer] + if self.ent_coef_optimizer is not None: + optimizers += [self.ent_coef_optimizer] + + # Update learning rate according to lr schedule + self._update_learning_rate(optimizers) + + ent_coef_losses, ent_coefs = [], [] + actor_losses, critic_losses = [], [] + + for gradient_step in range(gradient_steps): + # Sample replay buffer + replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) + + # We need to sample because `log_std` may have changed between two gradient steps + if self.use_sde: + self.actor.reset_noise() + + # Action by the current actor for the sampled state + actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations) + log_prob = log_prob.reshape(-1, 1) + + ent_coef_loss = None + if self.ent_coef_optimizer is not None: + # Important: detach the variable from the graph + # so we don't change it with other losses + # see https://github.com/rail-berkeley/softlearning/issues/60 + ent_coef = th.exp(self.log_ent_coef.detach()) + ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean() + ent_coef_losses.append(ent_coef_loss.item()) + else: + ent_coef = self.ent_coef_tensor + + ent_coefs.append(ent_coef.item()) + + # Optimize entropy coefficient, also called + # entropy temperature or alpha in the paper + if ent_coef_loss is not None: + self.ent_coef_optimizer.zero_grad() + ent_coef_loss.backward() + self.ent_coef_optimizer.step() + + with th.no_grad(): + # Select action according to policy + next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations) + # Compute the next Q values: min over all critics targets + next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1) + next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True) + # add entropy term + next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1) + # td error + entropy term + target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values + + # Get current Q-values estimates for each critic network + # using action from the replay buffer + current_q_values = self.critic(replay_data.observations, replay_data.actions) + + # Compute critic loss + critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values) + critic_losses.append(critic_loss.item()) + + # Optimize the critic + self.critic.optimizer.zero_grad() + critic_loss.backward() + self.critic.optimizer.step() + + # Compute actor loss + # Alternative: actor_loss = th.mean(log_prob - qf1_pi) + # Mean over all critic networks + q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1) + min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True) + actor_loss = (ent_coef * log_prob - min_qf_pi).mean() + actor_losses.append(actor_loss.item()) + + # Optimize the actor + self.actor.optimizer.zero_grad() + actor_loss.backward() + self.actor.optimizer.step() + + # Update target networks + if gradient_step % self.target_update_interval == 0: + polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau) + + self._n_updates += gradient_steps + + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/ent_coef", np.mean(ent_coefs)) + self.logger.record("train/actor_loss", np.mean(actor_losses)) + self.logger.record("train/critic_loss", np.mean(critic_losses)) + if len(ent_coef_losses) > 0: + self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) + + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 4, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "SAC", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> OffPolicyAlgorithm: + + return super().learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + eval_env=eval_env, + eval_freq=eval_freq, + n_eval_episodes=n_eval_episodes, + tb_log_name=tb_log_name, + eval_log_path=eval_log_path, + reset_num_timesteps=reset_num_timesteps, + ) + + def _excluded_save_params(self) -> List[str]: + return super()._excluded_save_params() + ["actor", "critic", "critic_target"] + + def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: + state_dicts = ["policy", "actor.optimizer", "critic.optimizer"] + if self.ent_coef_optimizer is not None: + saved_pytorch_variables = ["log_ent_coef"] + state_dicts.append("ent_coef_optimizer") + else: + saved_pytorch_variables = ["ent_coef_tensor"] + return state_dicts, saved_pytorch_variables