metastable-baselines2/metastable_baselines2/common/on_policy_algorithm.py


import sys
import time
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from .buffers import BetterDictRolloutBuffer, BetterRolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
SelfOnPolicyAlgorithm = TypeVar("SelfOnPolicyAlgorithm", bound="BetterOnPolicyAlgorithm")
class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
"""
    The base for On-Policy algorithms (ex: A2C/PPO), extended with optional PCA-based
    exploration and rollout buffers that additionally store the sampling distribution.

:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: The learning rate, it can be a function
of the current progress remaining (from 1 to 0)
:param n_steps: The number of steps to run for each environment per update
(i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
:param gamma: Discount factor
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator.
Equivalent to classic advantage when set to 1.
:param ent_coef: Entropy coefficient for the loss calculation
:param vf_coef: Value function coefficient for the loss calculation
:param max_grad_norm: The maximum value for the gradient clipping
:param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
    :param use_pca: Whether to use PCA-based exploration instead of gSDE or plain
        action-noise exploration. Mutually exclusive with ``use_sde``.
    :param pca_is: Whether rollouts are collected with ``conditioned_log_probs=True``,
        i.e. log probabilities conditioned on the sampled trajectory (intended for
        PCA importance sampling). Requires ``use_pca``.
    :param rollout_buffer_class: Must be left as ``None``; the buffer class
        (``BetterRolloutBuffer`` or ``BetterDictRolloutBuffer``) is selected automatically.
    :param rollout_buffer_kwargs: Must be left as ``None``; no extra keyword arguments
        are passed to the rollout buffer.
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param monitor_wrapper: When creating an environment, whether to wrap it
or not in a Monitor wrapper.
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: Whether or not to build the network at the creation of the instance
:param supported_action_spaces: The action spaces supported by the algorithm.
"""
def __init__(
self,
policy: Union[str, Type[ActorCriticPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule],
n_steps: int,
gamma: float,
gae_lambda: float,
ent_coef: float,
vf_coef: float,
max_grad_norm: float,
use_sde: bool,
sde_sample_freq: int,
use_pca: bool,
pca_is: bool,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
monitor_wrapper: bool = True,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
):
        assert not (use_sde and use_pca), "use_sde and use_pca are mutually exclusive"
        self.use_pca = use_pca
        assert not pca_is or use_pca, "pca_is requires use_pca"
        self.pca_is = pca_is
        assert not rollout_buffer_class and not rollout_buffer_kwargs, (
            "rollout_buffer_class and rollout_buffer_kwargs must be left as None; "
            "the rollout buffer is configured in _setup_model"
        )
super().__init__(
policy=policy,
env=env,
learning_rate=learning_rate,
n_steps=n_steps,
gamma=gamma,
gae_lambda=gae_lambda,
ent_coef=ent_coef,
vf_coef=vf_coef,
max_grad_norm=max_grad_norm,
policy_kwargs=policy_kwargs,
verbose=verbose,
device=device,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
            # rollout_buffer_class and rollout_buffer_kwargs are intentionally not forwarded
            # here; the rollout buffer is configured in the overridden _setup_model below.
seed=seed,
stats_window_size=stats_window_size,
tensorboard_log=tensorboard_log,
supported_action_spaces=supported_action_spaces,
monitor_wrapper=monitor_wrapper,
_init_setup_model=_init_setup_model
)
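        # Forward the number of parallel environments to the policy's action distribution
        # via ``dist_kwargs``; presumably the distribution keeps per-environment state
        # (e.g. for trajectory conditioning).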
if 'dist_kwargs' not in self.policy_kwargs:
self.policy_kwargs['dist_kwargs'] = {}
self.policy_kwargs['dist_kwargs']['n_envs'] = self.env.num_envs if hasattr(self.env, 'num_envs') else 1
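        # The base class would normally pick the rollout buffer from these attributes;
        # they are reset here and the buffer is instead chosen in _setup_model below.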
self.rollout_buffer_class = None
self.rollout_buffer_kwargs = {}
def _setup_model(self) -> None:
self._setup_lr_schedule()
self.set_random_seed(self.seed)
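        # Select the extended rollout buffers from this package: besides the usual SB3
        # fields they also store the parameters of the sampling distribution
        # (see ``collect_rollouts``).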
if self.rollout_buffer_class is None:
if isinstance(self.observation_space, spaces.Dict):
self.rollout_buffer_class = BetterDictRolloutBuffer
else:
self.rollout_buffer_class = BetterRolloutBuffer
self.policy = self.policy_class( # pytype:disable=not-instantiable
self.observation_space,
self.action_space,
self.lr_schedule,
use_sde=self.use_sde,
use_pca=self.use_pca,
**self.policy_kwargs # pytype:disable=not-instantiable
)
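        # ``full_cov`` signals the buffer to store a full Cholesky factor (scale_tril)
        # instead of a per-dimension scale; this presumably only applies to
        # full-covariance PCA distributions.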
self.rollout_buffer = self.rollout_buffer_class(
self.n_steps,
self.observation_space,
self.action_space,
device=self.device,
gamma=self.gamma,
gae_lambda=self.gae_lambda,
n_envs=self.n_envs,
full_cov=(self.use_pca and self.policy.action_dist.is_full()),
**self.rollout_buffer_kwargs,
)
self.policy = self.policy.to(self.device)
def collect_rollouts(
self,
env: VecEnv,
callback: BaseCallback,
rollout_buffer: RolloutBuffer,
n_rollout_steps: int,
) -> bool:
"""
Collect experiences using the current policy and fill a ``RolloutBuffer``.
The term rollout here refers to the model-free notion and should not
be used with the concept of rollout used in model-based RL or planning.
:param env: The training environment
:param callback: Callback that will be called at each step
(and at the beginning and end of the rollout)
:param rollout_buffer: Buffer to fill with rollouts
:param n_rollout_steps: Number of experiences to collect per environment
:return: True if function returned with at least `n_rollout_steps`
collected, False if callback terminated rollout prematurely.
"""
assert self._last_obs is not None, "No previous observation was provided"
# Switch to eval mode (this affects batch norm / dropout)
self.policy.set_training_mode(False)
n_steps = 0
rollout_buffer.reset()
# Sample new weights for the state dependent exploration
if self.use_sde:
self.policy.reset_noise(env.num_envs)
callback.on_rollout_start()
while n_steps < n_rollout_steps:
if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
# Sample a new noise matrix
self.policy.reset_noise(env.num_envs)
with th.no_grad():
# Convert to pytorch tensor or to TensorDict
obs_tensor = obs_as_tensor(self._last_obs, self.device)
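                # Unlike stock SB3, the policy also returns the action distribution, and
                # ``conditioned_log_probs=self.pca_is`` requests log probabilities
                # conditioned on the sampled trajectory (presumably for PCA importance
                # sampling).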
actions, values, log_probs, distributions = self.policy(obs_tensor, conditioned_log_probs=self.pca_is)
actions = actions.cpu().numpy()
# Rescale and perform action
clipped_actions = actions
# Clip the actions to avoid out of bound error
if isinstance(self.action_space, spaces.Box):
if self.policy.squash_output:
# Unscale the actions to match env bounds
# if they were previously squashed (scaled in [-1, 1])
clipped_actions = self.policy.unscale_action(clipped_actions)
else:
# Otherwise, clip the actions to avoid out of bound error
# as we are sampling from an unbounded Gaussian distribution
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
new_obs, rewards, dones, infos = env.step(clipped_actions)
self.num_timesteps += env.num_envs
# Give access to local variables
callback.update_locals(locals())
if not callback.on_step():
return False
self._update_info_buffer(infos)
n_steps += 1
if isinstance(self.action_space, spaces.Discrete):
# Reshape in case of discrete action
actions = actions.reshape(-1, 1)
# Handle timeout by bootstraping with value function
# see GitHub issue #633
for idx, done in enumerate(dones):
if (
done
and infos[idx].get("terminal_observation") is not None
and infos[idx].get("TimeLimit.truncated", False)
):
terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
with th.no_grad():
terminal_value = self.policy.predict_values(terminal_obs)[0]
rewards[idx] += self.gamma * terminal_value
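            # In addition to the usual SB3 fields, store the sampling distribution's mean
            # and scale (or its Cholesky factor ``scale_tril`` for full-covariance
            # distributions) so the update step can reconstruct the behaviour distribution.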
            rollout_buffer.add(
                self._last_obs,
                actions,
                rewards,
                self._last_episode_starts,
                values,
                log_probs,
                distributions.distribution.mean,
                distributions.distribution.scale
                if hasattr(distributions.distribution, 'scale')
                else distributions.distribution.scale_tril,
            )
self._last_obs = new_obs
self._last_episode_starts = dones
with th.no_grad():
# Compute value for the last timestep
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
callback.update_locals(locals())
callback.on_rollout_end()
return True
def train(self) -> None:
"""
Consume current rollout data and update policy parameters.
Implemented by individual algorithms.
"""
raise NotImplementedError
def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
        trajectory: Optional[th.Tensor] = None,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
:param observation: the input observation
:param state: The last hidden states (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
            this corresponds to the beginning of episodes,
where the hidden states of the RNN must be reset.
:param deterministic: Whether or not to return deterministic actions.
:param trajectory: Past trajectory. Only required when using PCA.
:return: the model's action and the next hidden state
(used in recurrent policies)
"""
return self.policy.predict(observation, state, episode_start, deterministic, trajectory=trajectory)
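    # Illustrative usage (names such as ``past_actions`` are hypothetical): with a PCA
    # policy, ``predict`` should be given the past trajectory so the conditioned
    # distribution can be evaluated, e.g.
    #   action, _ = model.predict(obs, deterministic=True, trajectory=past_actions)
    # where ``past_actions`` is a tensor of previously executed actions; without PCA,
    # ``trajectory`` can stay ``None``.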
def learn(
self: SelfOnPolicyAlgorithm,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
tb_log_name: str = "OnPolicyAlgorithm",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfOnPolicyAlgorithm:
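        """
        Standard on-policy training loop: repeatedly collect ``n_steps`` transitions per
        environment with ``collect_rollouts``, then update the policy with ``train``,
        until ``total_timesteps`` is reached or a callback stops training.
        """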
iteration = 0
total_timesteps, callback = self._setup_learn(
total_timesteps,
callback,
reset_num_timesteps,
tb_log_name,
progress_bar,
)
callback.on_training_start(locals(), globals())
while self.num_timesteps < total_timesteps:
            continue_training = self.collect_rollouts(
                self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps
            )
if not continue_training:
break
iteration += 1
self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
# Display training infos
if log_interval is not None and iteration % log_interval == 0:
assert self.ep_info_buffer is not None
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
self.logger.record("time/iterations", iteration, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
self.logger.record("time/fps", fps)
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
self.logger.dump(step=self.num_timesteps)
self.train()
callback.on_training_end()
return self
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
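        """
        Only the policy and its optimizer are saved as torch state-dicts; no other
        torch variables are stored with the model.
        """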
state_dicts = ["policy", "policy.optimizer"]
return state_dicts, []