import warnings
from typing import Any, Dict, Optional, Type, Union, NamedTuple

import numpy as np
import torch as th
from gym import spaces
from torch.nn import functional as F

from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.utils import obs_as_tensor
from stable_baselines3.common.vec_env import VecNormalize

from ..misc.distTools import new_dist_like, new_dist_like_from_sqrt

from metastable_projections.projections.base_projection_layer import BaseProjectionLayer
from metastable_projections.projections.frob_projection_layer import FrobeniusProjectionLayer
from metastable_projections.projections.w2_projection_layer import WassersteinProjectionLayer
from metastable_projections.projections.kl_projection_layer import KLProjectionLayer

from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass


class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
    """
    Differentiable Trust Region Layer (TRL) for Policy Gradient (PG)

    Paper: https://arxiv.org/abs/2101.09207
    Code: This implementation borrows (/steals most) code from SB3's PPO implementation
    https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/ppo/ppo.py
    The implementation of the TRL-specific parts borrows from
    https://github.com/boschresearch/trust-region-layers/blob/main/trust_region_projections/algorithms/pg/pg.py
    (Stolen from Fabian's Code (Public Version))

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. rollout buffer size is n_steps * n_envs where n_envs is number of
        environment copies running in parallel)
        NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization)
        See https://github.com/pytorch/pytorch/issues/29372
    :param batch_size: Minibatch size
    :param n_epochs: Number of epoch when optimizing the surrogate loss
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
    :param clip_range: Clipping parameter, it can be a function of the current progress
        remaining (from 1 to 0). If None is passed, the surrogate loss is not clipped.
    :param clip_range_vf: Clipping parameter for the value function,
        it can be a function of the current progress remaining (from 1 to 0).
        This is a parameter specific to the OpenAI implementation. If None is passed (default),
        no clipping will be done on the value function.
        IMPORTANT: this clipping depends on the reward scaling.
    :param normalize_advantage: Whether to normalize or not the advantage
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param action_coef: Coefficient of the action loss (mean squared action of the
        minibatch, 'principle of least action') that is added to the policy loss
    :param max_grad_norm: The maximum value for the gradient clipping
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param target_kl: Limit the KL divergence between updates,
        because the clipping is not enough to prevent large update
        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
        By default, there is no limit on the kl div.
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param create_eval_env: Whether to create a second environment that will be
        used for evaluating the agent periodically. (Only available when passing string for the environment)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
        (None is treated like an empty dict)
    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param projection: What kind of Projection to use
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    policy_aliases: Dict[str, Type[BasePolicy]] = {
        "MlpPolicy": ActorCriticPolicy,
        "CnnPolicy": ActorCriticCnnPolicy,
        "MultiInputPolicy": MultiInputActorCriticPolicy,
    }

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[None, float, Schedule] = 0.2,
        clip_range_vf: Union[None, float, Schedule] = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        action_coef: float = 0.0,
        max_grad_norm: Union[None, float] = 0.5,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        target_kl: Optional[float] = None,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        # Different from PPO:
        # Alternatives to the default are KLProjectionLayer(),
        # WassersteinProjectionLayer() and FrobeniusProjectionLayer().
        # NOTE(review): the default instance is created once at definition time
        # and is therefore shared by every PPO object that relies on it.
        projection: BaseProjectionLayer = BaseProjectionLayer(),
        _init_setup_model: bool = True,
    ):
        # Build a fresh dict so that neither the caller's dict nor a shared
        # default is touched, and so that `policy_kwargs=None` is accepted.
        # The projection-dependent flag takes precedence over a user-supplied one.
        merged_policy_kwargs = {
            **(policy_kwargs or {}),
            'sqrt_induced_gaussian': isinstance(
                projection, WassersteinProjectionLayer),
        }

        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            tensorboard_log=tensorboard_log,
            policy_kwargs=merged_policy_kwargs,
            verbose=verbose,
            device=device,
            create_eval_env=create_eval_env,
            seed=seed,
            _init_setup_model=False,
            supported_action_spaces=(
                spaces.Box,
                # spaces.Discrete,
                # spaces.MultiDiscrete,
                # spaces.MultiBinary,
            ),
        )

        # Sanity check, otherwise it will lead to noisy gradient and NaN
        # because of the advantage normalization
        if normalize_advantage:
            assert (
                batch_size > 1
            ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"

        if self.env is not None:
            # Check that `n_steps * n_envs > 1` to avoid NaN
            # when doing advantage normalization
            buffer_size = self.env.num_envs * self.n_steps
            assert (
                buffer_size > 1
            ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
            # Check that the rollout buffer size is a multiple of the mini-batch size
            untruncated_batches = buffer_size // batch_size
            if buffer_size % batch_size > 0:
                warnings.warn(
                    f"You have specified a mini-batch size of {batch_size},"
                    f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`,"
                    f" after every {untruncated_batches} untruncated mini-batches,"
                    f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n"
                    f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n"
                    f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})"
                )

        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.clip_range = clip_range
        self.clip_range_vf = clip_range_vf
        self.normalize_advantage = normalize_advantage
        self.target_kl = target_kl
        self.action_coef = action_coef

        # Different from PPO:
        self.projection = projection
        # Counts processed minibatches; handed to the projection on every call.
        self._global_steps = 0

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """
        Build the model via the parent class and turn the clipping
        parameters into schedule functions.

        After this call `self.clip_range` and `self.clip_range_vf` are either
        None (clipping disabled) or callables of the remaining progress.
        """
        super()._setup_model()

        # Initialize schedules for policy/value clipping
        if self.clip_range is not None:
            self.clip_range = get_schedule_fn(self.clip_range)
        if self.clip_range_vf is not None:
            if isinstance(self.clip_range_vf, (float, int)):
                assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"

            self.clip_range_vf = get_schedule_fn(self.clip_range_vf)

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.

        Runs up to `n_epochs` passes over the rollout buffer. For each
        minibatch the current policy distribution is projected with
        `self.projection` against the distribution stored during the rollout,
        and the loss is built from the (optionally clipped) surrogate loss,
        the entropy loss, the trust region loss, the action loss and the value
        loss. Training stops early when `target_kl` is set and the approximate
        KL divergence exceeds 1.5 times that value.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range (None disables surrogate clipping)
        if self.clip_range is not None:
            clip_range = self.clip_range(self._current_progress_remaining)
        else:
            clip_range = None
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(
                self._current_progress_remaining)
        else:
            clip_range_vf = None

        surrogate_losses = []
        entropy_losses = []
        trust_region_losses = []
        action_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True

        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                # This is new compared to PPO.
                # Calculating the TR-Projections we need to know the step number
                self._global_steps += 1

                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                # Re-sample the noise matrix because the log_std has changed
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                # Different from PPO
                # TRL-Projection-Action:
                pol = self.policy
                features = pol.extract_features(rollout_data.observations)
                latent_pi, latent_vf = pol.mlp_extractor(features)
                p = pol._get_action_dist_from_latent(latent_pi)
                p_dist = p.distribution
                # Rebuild the rollout-time distribution from the stored means/chols
                if isinstance(self.projection, WassersteinProjectionLayer):
                    q_dist = new_dist_like_from_sqrt(
                        p_dist, rollout_data.means, rollout_data.chols)
                else:
                    q_dist = new_dist_like(
                        p_dist, rollout_data.means, rollout_data.chols)
                proj_p = self.projection(p_dist, q_dist, self._global_steps)
                if isinstance(p_dist, th.distributions.Normal):
                    # Normal uses a weird mapping from dimensions into batch_shape
                    log_prob = proj_p.log_prob(actions).sum(dim=1)
                else:
                    # UniversalGaussianDistribution instead uses Independent (or MultivariateNormal),
                    # which has a more rational dim mapping
                    log_prob = proj_p.log_prob(actions)
                values = self.policy.value_net(latent_vf)
                entropy = proj_p.entropy()

                values = values.flatten()
                # Normalize advantage
                advantages = rollout_data.advantages
                if self.normalize_advantage:
                    advantages = (advantages - advantages.mean()
                                  ) / (advantages.std() + 1e-8)

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss'
                # clipped surrogate loss
                if clip_range is None:
                    surrogate_loss = -(advantages * ratio).mean()
                else:
                    surrogate_loss_1 = advantages * ratio
                    surrogate_loss_2 = advantages * \
                        th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                    surrogate_loss = - \
                        th.min(surrogate_loss_1, surrogate_loss_2).mean()

                surrogate_losses.append(surrogate_loss.item())

                if clip_range is None:
                    clip_fraction = 0
                else:
                    clip_fraction = th.mean(
                        (th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the different between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)

                entropy_losses.append(entropy_loss.item())

                # Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
                trust_region_loss = self.projection.get_trust_region_loss(
                    p, proj_p)

                trust_region_losses.append(trust_region_loss.item())

                # 'Principle of least action'
                action_loss = th.mean(th.square(actions))

                # Store a plain float like every other loss list, so that
                # np.mean below never has to convert (possibly CUDA) tensors.
                action_losses.append(action_loss.item())

                policy_loss = surrogate_loss + self.ent_coef * entropy_loss + \
                    trust_region_loss + self.action_coef * action_loss
                pg_losses.append(policy_loss.item())

                loss = policy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean(
                        (th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(
                            f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                if self.max_grad_norm is not None:
                    th.nn.utils.clip_grad_norm_(
                        self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            if not continue_training:
                break

        self._n_updates += self.n_epochs
        explained_var = explained_variance(
            self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        self.logger.record("train/surrogate_loss", np.mean(surrogate_losses))
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/trust_region_loss",
                           np.mean(trust_region_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/action_loss", np.mean(action_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        if hasattr(self.policy, "log_std"):
            self.logger.record(
                "train/std", th.exp(self.policy.log_std).mean().item())
        elif hasattr(self.policy, "chol"):
            chol = self.policy.chol
            if len(chol.shape) == 1:
                self.logger.record(
                    "train/std", th.mean(chol).mean().item())
            elif len(chol.shape) == 2:
                self.logger.record(
                    "train/std", th.mean(th.sqrt(th.diagonal(chol.T @ chol, dim1=-2, dim2=-1))).mean().item())
            else:
                self.logger.record(
                    "train/std", th.mean(th.sqrt(th.diagonal(chol.mT @ chol, dim1=-2, dim2=-1))).mean().item())

        self.logger.record("train/n_updates",
                           self._n_updates, exclude="tensorboard")
        if self.clip_range is not None:
            self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "PPO",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
    ) -> "PPO":
        """
        Run the training loop of the parent class.

        All arguments are forwarded unchanged to the parent `learn`;
        this override only sets the default `tb_log_name` to "PPO".

        :return: whatever the parent `learn` returns (annotated as this class)
        """
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
        )