from typing import NamedTuple, Optional, Union

import numpy as np
import torch as th
from gym import spaces

from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.utils import obs_as_tensor
from stable_baselines3.common.vec_env import VecEnv, VecNormalize


# TRL requires the original mean and covariance of the policy at the time the
# datapoint was collected. GaussianRolloutBuffer extends the RolloutBuffer with
# these two fields.


class GaussianRolloutBufferSamples(NamedTuple):
    observations: th.Tensor
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor
    means: th.Tensor
    stds: th.Tensor

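
# Illustrative sketch (not part of the original buffer code): the stored means/stds
# are meant to let the update step measure how far the current Gaussian policy has
# drifted from the one that generated the data. Assuming diagonal covariances, the
# closed-form KL divergence between the old and new per-sample distributions could
# be computed roughly like this:
def _diag_gaussian_kl(mean_old: th.Tensor, std_old: th.Tensor,
                      mean_new: th.Tensor, std_new: th.Tensor) -> th.Tensor:
    """KL(N(mean_old, std_old^2) || N(mean_new, std_new^2)) for diagonal Gaussians,
    summed over action dimensions. Purely illustrative helper."""
    var_old, var_new = std_old.pow(2), std_new.pow(2)
    kl = th.log(std_new / std_old) + (var_old + (mean_old - mean_new).pow(2)) / (2.0 * var_new) - 0.5
    return kl.sum(dim=-1)
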

class GaussianRolloutBuffer(RolloutBuffer):
    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
        cov_shape: Optional[tuple] = None,
    ):
        self.means, self.stds = None, None
        # TODO: The correct shape for a full covariance matrix would be
        # self.action_space.shape + self.action_space.shape
        if cov_shape is None:
            cov_shape = action_space.shape
        self.cov_shape = cov_shape

        # super().__init__ has to come last: RolloutBuffer.__init__ calls reset(),
        # which requires self.cov_shape to be set already.
        super().__init__(buffer_size, observation_space, action_space,
                         device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma)

    def reset(self) -> None:
        self.means = np.zeros(
            (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
        self.stds = np.zeros(
            (self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32)
        super().reset()

    def add(
        self,
        obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
        mean: th.Tensor,
        std: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward: Reward obtained from the environment
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        :param mean: mean of the policy's Gaussian action distribution
            at the time the action was sampled.
        :param std: standard deviation of the policy's Gaussian action
            distribution at the time the action was sampled.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs,) + self.obs_shape)

        self.observations[self.pos] = np.array(obs).copy()
        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.episode_starts[self.pos] = np.array(episode_start).copy()
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.means[self.pos] = mean.clone().cpu().numpy()
        self.stds[self.pos] = std.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> GaussianRolloutBufferSamples:
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
            self.means[batch_inds].reshape((len(batch_inds), -1)),
            self.stds[batch_inds].reshape((len(batch_inds), -1)),
        )
        return GaussianRolloutBufferSamples(*tuple(map(self.to_torch, data)))

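# Note: the extra fields surface through the inherited RolloutBuffer.get() generator,
# so a (hypothetical) trust-region update loop could read them alongside the usual
# PPO fields, e.g.:
#
#     for rollout_data in rollout_buffer.get(batch_size=64):
#         old_mean, old_std = rollout_data.means, rollout_data.stds
#         ...
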

class GaussianRolloutCollectorAuxclass:
    def _setup_model(self) -> None:
        super()._setup_model()

        self.rollout_buffer = GaussianRolloutBuffer(
            self.n_steps,
            self.observation_space,
            self.action_space,
            device=self.device,
            gamma=self.gamma,
            gae_lambda=self.gae_lambda,
            n_envs=self.n_envs,
        )

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        rollout_buffer: RolloutBuffer,
        n_rollout_steps: int,
    ) -> bool:
        """
        Collect experiences using the current policy and fill a ``RolloutBuffer``.

        The term rollout here refers to the model-free notion and should not
        be used with the concept of rollout used in model-based RL or planning.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param rollout_buffer: Buffer to fill with rollouts
        :param n_rollout_steps: Number of experiences to collect per environment
        :return: True if function returned with at least `n_rollout_steps`
            collected, False if callback terminated rollout prematurely.
        """
        assert self._last_obs is not None, "No previous observation was provided"
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)

        n_steps = 0
        rollout_buffer.reset()
        # Sample new weights for the state dependent exploration
        if self.use_sde:
            self.policy.reset_noise(env.num_envs)

        callback.on_rollout_start()

        while n_steps < n_rollout_steps:
            if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.policy.reset_noise(env.num_envs)

            with th.no_grad():
                # Convert to pytorch tensor or to TensorDict
                obs_tensor = obs_as_tensor(self._last_obs, self.device)
                actions, values, log_probs = self.policy(obs_tensor)
                # Also record the Gaussian the action was sampled from,
                # so the buffer can keep the old mean and std for TRL
                dist = self.policy.get_distribution(obs_tensor).distribution
                mean, std = dist.mean, dist.stddev
            actions = actions.cpu().numpy()

            # Rescale and perform action
            clipped_actions = actions
            # Clip the actions to avoid out of bound error
            if isinstance(self.action_space, spaces.Box):
                clipped_actions = np.clip(
                    actions, self.action_space.low, self.action_space.high)

            new_obs, rewards, dones, infos = env.step(clipped_actions)

            self.num_timesteps += env.num_envs

            # Give access to local variables
            callback.update_locals(locals())
            if callback.on_step() is False:
                return False

            self._update_info_buffer(infos)
            n_steps += 1

            if isinstance(self.action_space, spaces.Discrete):
                # Reshape in case of discrete action
                actions = actions.reshape(-1, 1)

            # Handle timeout by bootstrapping with the value function
            # see GitHub issue #633
            for idx, done in enumerate(dones):
                if (
                    done
                    and infos[idx].get("terminal_observation") is not None
                    and infos[idx].get("TimeLimit.truncated", False)
                ):
                    terminal_obs = self.policy.obs_to_tensor(
                        infos[idx]["terminal_observation"])[0]
                    with th.no_grad():
                        terminal_value = self.policy.predict_values(terminal_obs)[0]
                    rewards[idx] += self.gamma * terminal_value

            rollout_buffer.add(self._last_obs, actions, rewards,
                               self._last_episode_starts, values, log_probs, mean, std)
            self._last_obs = new_obs
            self._last_episode_starts = dones

        with th.no_grad():
            # Compute value for the last timestep
            values = self.policy.predict_values(
                obs_as_tensor(new_obs, self.device))

        rollout_buffer.compute_returns_and_advantage(
            last_values=values, dones=dones)

        callback.on_rollout_end()

        return True
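

# Minimal usage sketch (an assumption, not part of the original module): the mixin is
# meant to sit in front of an SB3 on-policy algorithm in the MRO so that its
# _setup_model() and collect_rollouts() override the defaults. The subclass name
# TRLPPO below is hypothetical, and a continuous (Gaussian) action space is assumed
# because collect_rollouts() reads dist.mean and dist.stddev. Plain PPO will simply
# ignore the extra means/stds fields; a real TRL loss would consume them in train().
if __name__ == "__main__":
    from stable_baselines3 import PPO

    class TRLPPO(GaussianRolloutCollectorAuxclass, PPO):
        """PPO whose rollout buffer also records the old policy means and stds."""

    model = TRLPPO("MlpPolicy", "Pendulum-v1", n_steps=64)
    model.learn(total_timesteps=256)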