from typing import Any, Dict, Optional, Type, Union, NamedTuple import numpy as np import torch as th from gym import spaces from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.vec_env import VecNormalize from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.utils import obs_as_tensor # TRL requires the origina mean and covariance from the policy when the datapoint was created. # GaussianRolloutBuffer extends the RolloutBuffer by these two fields class GaussianRolloutBufferSamples(NamedTuple): observations: th.Tensor actions: th.Tensor old_values: th.Tensor old_log_prob: th.Tensor advantages: th.Tensor returns: th.Tensor means: th.Tensor stds: th.Tensor class GaussianRolloutBuffer(RolloutBuffer): def __init__( self, buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, device: Union[th.device, str] = "cpu", gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, cov_shape=None, ): self.means, self.stds = None, None # TODO: Correct shape for full cov matrix # self.action_space.shape + self.action_space.shape if cov_shape == None: cov_shape = action_space.shape self.cov_shape = cov_shape # It is ugly, but necessary to put this at the bottom of the init... super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma) def reset(self) -> None: self.means = np.zeros( (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32) self.stds = np.zeros( (self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32) super().reset() def add( self, obs: np.ndarray, action: np.ndarray, reward: np.ndarray, episode_start: np.ndarray, value: th.Tensor, log_prob: th.Tensor, mean: th.Tensor, std: th.Tensor, ) -> None: """ :param obs: Observation :param action: Action :param reward: :param episode_start: Start of episode signal. :param value: estimated value of the current state following the current policy. :param log_prob: log probability of the action following the current policy. :param mean: Foo :param std: Bar """ if len(log_prob.shape) == 0: # Reshape 0-d tensor to avoid error log_prob = log_prob.reshape(-1, 1) # Reshape needed when using multiple envs with discrete observations # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) if isinstance(self.observation_space, spaces.Discrete): obs = obs.reshape((self.n_envs,) + self.obs_shape) self.observations[self.pos] = np.array(obs).copy() self.actions[self.pos] = np.array(action).copy() self.rewards[self.pos] = np.array(reward).copy() self.episode_starts[self.pos] = np.array(episode_start).copy() self.values[self.pos] = value.clone().cpu().numpy().flatten() self.log_probs[self.pos] = log_prob.clone().cpu().numpy() self.means[self.pos] = mean.clone().cpu().numpy() self.stds[self.pos] = std.clone().cpu().numpy() self.pos += 1 if self.pos == self.buffer_size: self.full = True def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> GaussianRolloutBufferSamples: data = ( self.observations[batch_inds], self.actions[batch_inds], self.values[batch_inds].flatten(), self.log_probs[batch_inds].flatten(), self.advantages[batch_inds].flatten(), self.returns[batch_inds].flatten(), self.means[batch_inds].reshape((len(batch_inds), -1)), self.stds[batch_inds].reshape((len(batch_inds), -1)), ) return GaussianRolloutBufferSamples(*tuple(map(self.to_torch, data))) class GaussianRolloutCollectorAuxclass(): def _setup_model(self) -> None: super()._setup_model() self.rollout_buffer = GaussianRolloutBuffer( self.n_steps, self.observation_space, self.action_space, device=self.device, gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs, ) def collect_rollouts( self, env: VecEnv, callback: BaseCallback, rollout_buffer: RolloutBuffer, n_rollout_steps: int, ) -> bool: """ Collect experiences using the current policy and fill a ``RolloutBuffer``. The term rollout here refers to the model-free notion and should not be used with the concept of rollout used in model-based RL or planning. :param env: The training environment :param callback: Callback that will be called at each step (and at the beginning and end of the rollout) :param rollout_buffer: Buffer to fill with rollouts :param n_steps: Number of experiences to collect per environment :return: True if function returned with at least `n_rollout_steps` collected, False if callback terminated rollout prematurely. """ assert self._last_obs is not None, "No previous observation was provided" # Switch to eval mode (this affects batch norm / dropout) self.policy.set_training_mode(False) n_steps = 0 rollout_buffer.reset() # Sample new weights for the state dependent exploration if self.use_sde: self.policy.reset_noise(env.num_envs) callback.on_rollout_start() while n_steps < n_rollout_steps: if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: # Sample a new noise matrix self.policy.reset_noise(env.num_envs) with th.no_grad(): # Convert to pytorch tensor or to TensorDict obs_tensor = obs_as_tensor(self._last_obs, self.device) actions, values, log_probs = self.policy(obs_tensor) dist = self.policy.get_distribution(obs_tensor).distribution mean, std = dist.mean, dist.stddev actions = actions.cpu().numpy() # Rescale and perform action clipped_actions = actions # Clip the actions to avoid out of bound error if isinstance(self.action_space, spaces.Box): clipped_actions = np.clip( actions, self.action_space.low, self.action_space.high) new_obs, rewards, dones, infos = env.step(clipped_actions) self.num_timesteps += env.num_envs # Give access to local variables callback.update_locals(locals()) if callback.on_step() is False: return False self._update_info_buffer(infos) n_steps += 1 if isinstance(self.action_space, spaces.Discrete): # Reshape in case of discrete action actions = actions.reshape(-1, 1) # Handle timeout by bootstraping with value function # see GitHub issue #633 for idx, done in enumerate(dones): if ( done and infos[idx].get("terminal_observation") is not None and infos[idx].get("TimeLimit.truncated", False) ): terminal_obs = self.policy.obs_to_tensor( infos[idx]["terminal_observation"])[0] with th.no_grad(): terminal_value = self.policy.predict_values(terminal_obs)[ 0] rewards[idx] += self.gamma * terminal_value rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, mean, std) self._last_obs = new_obs self._last_episode_starts = dones with th.no_grad(): # Compute value for the last timestep values = self.policy.predict_values( obs_as_tensor(new_obs, self.device)) rollout_buffer.compute_returns_and_advantage( last_values=values, dones=dones) callback.on_rollout_end() return True