441 lines
20 KiB
Python
441 lines
20 KiB
Python
import warnings
|
|
from typing import Any, Dict, Optional, Type, Union, NamedTuple
|
|
|
|
import numpy as np
|
|
import torch as th
|
|
from gym import spaces
|
|
from torch.nn import functional as F
|
|
|
|
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
|
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
|
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
|
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
|
|
from stable_baselines3.common.vec_env import VecEnv
|
|
from stable_baselines3.common.buffers import RolloutBuffer
|
|
from stable_baselines3.common.callbacks import BaseCallback
|
|
from stable_baselines3.common.utils import obs_as_tensor
|
|
from stable_baselines3.common.vec_env import VecNormalize
|
|
|
|
from ..misc.distTools import new_dist_like, new_dist_like_from_sqrt
|
|
|
|
from metastable_projections.projections.base_projection_layer import BaseProjectionLayer
|
|
from metastable_projections.projections.frob_projection_layer import FrobeniusProjectionLayer
|
|
from metastable_projections.projections.w2_projection_layer import WassersteinProjectionLayer
|
|
from metastable_projections.projections.kl_projection_layer import KLProjectionLayer
|
|
|
|
from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass
|
|
|
|
from copy import deepcopy
|
|
|
|
|
|
class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
    """
    Differentiable Trust Region Layers (TRL) for Policy Gradient (PG) methods, built on PPO.

    Paper: https://arxiv.org/abs/2101.09207

    Code: This implementation borrows most of its code from SB3's PPO implementation
    https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/ppo/ppo.py
    The TRL-specific parts borrow from
    https://github.com/boschresearch/trust-region-layers/blob/main/trust_region_projections/algorithms/pg/pg.py
    (public version of Fabian's code).

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel)
        NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization)
        See https://github.com/pytorch/pytorch/issues/29372
    :param batch_size: Minibatch size
    :param n_epochs: Number of epochs when optimizing the surrogate loss
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
    :param clip_range: Clipping parameter, it can be a function of the current progress
        remaining (from 1 to 0). If None, the surrogate loss is not clipped.
    :param clip_range_vf: Clipping parameter for the value function,
        it can be a function of the current progress remaining (from 1 to 0).
        This is a parameter specific to the OpenAI implementation. If None is passed (default),
        no clipping will be done on the value function.
        IMPORTANT: this clipping depends on the reward scaling.
    :param normalize_advantage: Whether to normalize or not the advantage
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param action_coef: Coefficient of the 'least action' term (mean squared action)
        added to the policy loss
    :param max_grad_norm: The maximum value for the gradient clipping
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param target_kl: Limit the KL divergence between updates,
        because the clipping is not enough to prevent large updates;
        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
        By default, there is no limit on the kl div.
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
        (``sqrt_induced_gaussian`` is always set by this class, see ``projection``)
    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param projection: Trust-region projection layer applied to the policy distribution,
        e.g. ``KLProjectionLayer()``, ``WassersteinProjectionLayer()`` or
        ``FrobeniusProjectionLayer()``. If None (default), a fresh ``BaseProjectionLayer()``
        is created for this instance. Passing a ``WassersteinProjectionLayer`` makes the
        policy use ``sqrt_induced_gaussian=True``.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    policy_aliases: Dict[str, Type[BasePolicy]] = {
        "MlpPolicy": ActorCriticPolicy,
        "CnnPolicy": ActorCriticCnnPolicy,
        "MultiInputPolicy": MultiInputActorCriticPolicy,
    }

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[None, float, Schedule] = 0.2,
        clip_range_vf: Union[None, float, Schedule] = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        action_coef: float = 0.0,
        max_grad_norm: Union[None, float] = 0.5,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        target_kl: Optional[float] = None,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        # Different from PPO: the trust-region projection layer (None -> BaseProjectionLayer()).
        projection: Optional[BaseProjectionLayer] = None,
        _init_setup_model: bool = True,
    ):
        # Build per-instance defaults here: default-argument expressions are evaluated
        # once at definition time, so `{}` / `BaseProjectionLayer()` defaults would be
        # shared by every PPO instance (and `policy_kwargs=None` used to crash on `|`).
        if projection is None:
            projection = BaseProjectionLayer()
        # Copy so the caller's dict is never modified; the projection type decides
        # whether the policy uses the sqrt-induced Gaussian parameterization.
        merged_policy_kwargs = dict(policy_kwargs or {})
        merged_policy_kwargs["sqrt_induced_gaussian"] = isinstance(projection, WassersteinProjectionLayer)

        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            tensorboard_log=tensorboard_log,
            policy_kwargs=merged_policy_kwargs,
            verbose=verbose,
            device=device,
            seed=seed,
            _init_setup_model=False,
            supported_action_spaces=(
                spaces.Box,
                # spaces.Discrete,
                # spaces.MultiDiscrete,
                # spaces.MultiBinary,
            ),
        )

        # Sanity check, otherwise it will lead to noisy gradient and NaN
        # because of the advantage normalization
        if normalize_advantage:
            assert (
                batch_size > 1
            ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"

        if self.env is not None:
            # Check that `n_steps * n_envs > 1` to avoid NaN
            # when doing advantage normalization
            buffer_size = self.env.num_envs * self.n_steps
            assert (
                buffer_size > 1
            ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
            # Check that the rollout buffer size is a multiple of the mini-batch size
            untruncated_batches = buffer_size // batch_size
            if buffer_size % batch_size > 0:
                warnings.warn(
                    f"You have specified a mini-batch size of {batch_size},"
                    f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`,"
                    f" after every {untruncated_batches} untruncated mini-batches,"
                    f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n"
                    f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n"
                    f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})"
                )
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.clip_range = clip_range
        self.clip_range_vf = clip_range_vf
        self.normalize_advantage = normalize_advantage
        self.target_kl = target_kl

        self.action_coef = action_coef

        # Different from PPO:
        self.projection = projection
        # Minibatch counter across all train() calls; passed to the projection layer.
        self._global_steps = 0

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """Build the model via the base class and turn clip ranges into schedules."""
        super()._setup_model()

        # Initialize schedules for policy/value clipping
        if self.clip_range is not None:
            self.clip_range = get_schedule_fn(self.clip_range)
        if self.clip_range_vf is not None:
            if isinstance(self.clip_range_vf, (float, int)):
                assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"

            self.clip_range_vf = get_schedule_fn(self.clip_range_vf)

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.

        Differences from SB3's PPO.train:

        - The current policy distribution is projected onto the trust region around the
          rollout (old) distribution with ``self.projection``. The projected distribution
          supplies the log-probabilities and entropy.
        - ``self.projection.get_trust_region_loss`` and ``action_coef`` times the mean
          squared action are added to the policy loss.
        - If building the action distribution raises ``ValueError``, the code treats it as
          exploded gradients. It restores the policy ``state_dict`` snapshot taken before
          the last optimizer step and skips the rest of the current epoch.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = None
        if self.clip_range is not None:
            clip_range = self.clip_range(self._current_progress_remaining)
        # Optional: clip range for the value function
        clip_range_vf = None
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

        surrogate_losses = []
        entropy_losses = []
        trust_region_losses = []
        action_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []
        approx_kl_divs = []
        # Last total loss; stays None if no minibatch completes (e.g. n_epochs == 0).
        loss = None

        setback_ctr = 0
        # Snapshot of the policy weights used to roll back after a failed forward pass.
        # NOTE(review): only the module state_dict is restored, not the optimizer state.
        bak = deepcopy(self.policy.state_dict())

        continue_training = True

        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                # This is new compared to PPO.
                # The TR-projections need to know the global step number
                self._global_steps += 1

                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                # Re-sample the noise matrix because the log_std has changed
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                # Different from PPO
                # TRL-Projection-Action:
                pol = self.policy
                features = pol.extract_features(rollout_data.observations)
                latent_pi, latent_vf = pol.mlp_extractor(features)
                try:
                    p = pol._get_action_dist_from_latent(latent_pi)
                except ValueError:
                    self.policy.load_state_dict(bak)
                    setback_ctr += 1
                    print(f"[!] Gradients Exploded; reseting to last known states (setback number {setback_ctr})")
                    break
                # Free the old snapshot before allocating the new one (lower peak memory);
                # the new snapshot is taken before this minibatch's optimizer step.
                del bak
                bak = deepcopy(self.policy.state_dict())

                p_dist = p.distribution
                if isinstance(self.projection, WassersteinProjectionLayer):
                    q_dist = new_dist_like_from_sqrt(p_dist, rollout_data.means, rollout_data.chols)
                else:
                    q_dist = new_dist_like(p_dist, rollout_data.means, rollout_data.chols)
                proj_p = self.projection(p_dist, q_dist, self._global_steps)
                if isinstance(p_dist, th.distributions.Normal):
                    # Normal maps the action dimensions into batch_shape, so sum them here
                    log_prob = proj_p.log_prob(actions).sum(dim=1)
                else:
                    # UniversalGaussianDistribution uses Independent (or MultivariateNormal),
                    # whose log_prob already reduces over the action dimensions
                    log_prob = proj_p.log_prob(actions)
                values = self.policy.value_net(latent_vf)
                entropy = proj_p.entropy()

                values = values.flatten()
                # Normalize advantage. Skip single-sample (truncated) minibatches:
                # their std() is NaN and would poison the whole update.
                advantages = rollout_data.advantages
                if self.normalize_advantage and len(advantages) > 1:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # Difference from PPO: 'policy_loss' is renamed to 'surrogate_loss'
                if clip_range is None:
                    surrogate_loss = -(advantages * ratio).mean()
                    clip_fraction = 0
                else:
                    # clipped surrogate loss
                    surrogate_loss_1 = advantages * ratio
                    surrogate_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                    surrogate_loss = -th.min(surrogate_loss_1, surrogate_loss_2).mean()
                    clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                surrogate_losses.append(surrogate_loss.item())
                clip_fractions.append(clip_fraction)

                if clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)
                entropy_losses.append(entropy_loss.item())

                # Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
                trust_region_loss = self.projection.get_trust_region_loss(p, proj_p)
                trust_region_losses.append(trust_region_loss.item())

                # 'Principle of least action': penalize large actions.
                # NOTE(review): `actions` come from the rollout buffer and presumably do not
                # depend on the policy parameters, so this term likely contributes no gradient
                # (it only shifts the reported loss). Confirm whether e.g. the projected
                # policy mean was meant to be penalized instead.
                action_loss = th.mean(th.square(actions))
                action_losses.append(action_loss.item())

                policy_loss = surrogate_loss + self.ent_coef * entropy_loss + \
                    trust_region_loss + self.action_coef * action_loss
                pg_losses.append(policy_loss.item())

                loss = policy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                if self.max_grad_norm is not None:
                    th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            # Count epochs actually run, so early stopping is reflected in n_updates.
            self._n_updates += 1
            if not continue_training:
                break

        explained_var = explained_variance(
            self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        self.logger.record("train/surrogate_loss", np.mean(surrogate_losses))
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/trust_region_loss", np.mean(trust_region_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/action_loss", np.mean(action_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        if loss is not None:
            self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        self.logger.record("train/ssf", self.sde_sample_freq)
        self._record_policy_std()

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        if clip_range is not None:
            self.logger.record("train/clip_range", clip_range)
        if clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def _record_policy_std(self) -> None:
        """Log the policy's mean action standard deviation as ``train/std``, if the policy exposes one."""
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
        elif hasattr(self.policy, "chol"):
            chol = self.policy.chol
            if len(chol.shape) == 1:
                # 1-D: presumably already per-dimension std values — TODO confirm
                std = th.mean(chol)
            else:
                # .T for plain matrices; .mT transposes the last two dims of batched tensors.
                # NOTE(review): for a lower-triangular Cholesky factor L the covariance is
                # L @ L^T, whose diagonal differs from that of L^T @ L computed here —
                # confirm which convention self.policy.chol follows.
                chol_t = chol.T if len(chol.shape) == 2 else chol.mT
                std = th.mean(th.sqrt(th.diagonal(chol_t @ chol, dim1=-2, dim2=-1)))
            self.logger.record("train/std", std.mean().item())

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        tb_log_name: str = "PPO",
        reset_num_timesteps: bool = True,
    ) -> "PPO":
        """Train for ``total_timesteps`` by delegating to the base class ``learn``; returns the result of that call."""
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
        )
|