Implemented TRLRolloutBuffer

Dominik Moritz Roth 2022-06-25 21:47:39 +02:00
parent 60c954c8c1
commit b8488c531b


@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, Dict, Optional, Type, Union
+from typing import Any, Dict, Optional, Type, Union, NamedTuple
 import numpy as np
 import torch as th
@@ -14,9 +14,10 @@ from stable_baselines3.common.vec_env import VecEnv
 from stable_baselines3.common.buffers import RolloutBuffer
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.utils import obs_as_tensor
+from stable_baselines3.common.vec_env import VecNormalize
 from ..projections.base_projection_layer import BaseProjectionLayer
-from ..projections.frob_projection_layer import FrobeniusProjectionLayer
+# from ..projections.frob_projection_layer import FrobeniusProjectionLayer
 class TRL_PG(OnPolicyAlgorithm):
@@ -56,7 +57,8 @@ class TRL_PG(OnPolicyAlgorithm):
         Default: -1 (only sample at the beginning of the rollout)
     :param target_kl: Limit the KL divergence between updates,
         because the clipping is not enough to prevent large update
-        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
+        see issue
+        #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
         By default, there is no limit on the kl div.
     :param tensorboard_log: the log location for tensorboard (if None, no logging)
     :param create_eval_env: Whether to create a second environment that will be
@@ -103,7 +105,7 @@ class TRL_PG(OnPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
         # Different from PPO:
-        projection: BaseProjectionLayer = BaseProjectionLayer,
+        projection: BaseProjectionLayer = BaseProjectionLayer(),
         _init_setup_model: bool = True,
     ):
@@ -129,9 +131,9 @@ class TRL_PG(OnPolicyAlgorithm):
             _init_setup_model=False,
             supported_action_spaces=(
                 spaces.Box,
-                spaces.Discrete,
-                spaces.MultiDiscrete,
-                spaces.MultiBinary,
+                # spaces.Discrete,
+                # spaces.MultiDiscrete,
+                # spaces.MultiBinary,
             ),
         )
@@ -185,6 +187,17 @@ class TRL_PG(OnPolicyAlgorithm):
             self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
 
+        # Changed from PPO: We need a bigger RolloutBuffer
+        self.rollout_buffer = TRLRolloutBuffer(
+            self.n_steps,
+            self.observation_space,
+            self.action_space,
+            device=self.device,
+            gamma=self.gamma,
+            gae_lambda=self.gae_lambda,
+            n_envs=self.n_envs,
+        )
+
     def train(self) -> None:
         """
         Update policy using the currently gathered rollout buffer.
@@ -237,24 +250,36 @@ class TRL_PG(OnPolicyAlgorithm):
                 # log_prob == new_pogpacs (i think)
                 # src of evaluate_actions:
-                # features = self.extract_features(obs)
-                # latent_pi, latent_vf = self.mlp_extractor(features)
-                # distribution = self._get_action_dist_from_latent(latent_pi)
+                # pol = self.policy
+                # features = pol.extract_features(rollout_data.observations)
+                # latent_pi, latent_vf = pol.mlp_extractor(features)
+                # distribution = pol._get_action_dist_from_latent(latent_pi)
                 # log_prob = distribution.log_prob(actions)
-                # values = self.value_net(latent_vf)
+                # values = pol.value_net(latent_vf)
                 # return values, log_prob, distribution.entropy()
+                # entropy = distribution.entropy()
 
                 # here we go:
                 pol = self.policy
                 features = pol.extract_features(rollout_data.observations)
                 latent_pi, latent_vf = pol.mlp_extractor(features)
                 p = pol._get_action_dist_from_latent(latent_pi)
-                b_q = rollout_data.mean, rollout_data.std
-                proj_p = self.projection(pol, p, b_q, self._global_step)
-                log_prob = proj_p.log_prob(actions)
-                # or log_prob = pol.log_probability(proj_p, actions)
-                values = self.value_net(latent_vf)
-                entropy = proj_p.entropy()  # or not...
+                p_dist = p.distribution
+                # q_means = rollout_data.means
+                # if len(rollout_data.stds.shape) == 1:  # only diag
+                # q_stds = th.diag(rollout_data.stds)
+                # else:
+                # q_stds = rollout_data.stds
+                # q_dist = th.distributions.MultivariateNormal(
+                #     q_means, q_stds)
+                q_dist = th.distributions.Normal(
+                    rollout_data.means, rollout_data.stds)
+                proj_p = self.projection(p_dist, q_dist, self._global_steps)
+                log_prob = proj_p.log_prob(actions).sum(dim=1)
+                values = self.policy.value_net(latent_vf)
+                entropy = proj_p.entropy()
+                # log_prob = p.log_prob(actions)
 
                 values = values.flatten()
                 # Normalize advantage
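
A note on the .sum(dim=1) above: th.distributions.Normal is a factorized, per-dimension distribution, so log_prob returns one log-density per action dimension and the joint log-probability of an action vector is the sum over that axis (the same applies to entropy(), which is also per-dimension). A minimal, self-contained sketch of that assumption with toy shapes, independent of this repo's policy classes:

import torch as th

# Toy batch: 4 samples of a 3-dimensional action under a diagonal Gaussian policy
means = th.zeros(4, 3)
stds = th.ones(4, 3)
actions = th.randn(4, 3)

q = th.distributions.Normal(means, stds)
per_dim = q.log_prob(actions)   # shape (4, 3): one log-density per action dimension
joint = per_dim.sum(dim=1)      # shape (4,): joint log-prob of each action vector

# Same joint density via an explicitly multivariate (diagonal-covariance) Gaussian
mvn = th.distributions.MultivariateNormal(means, th.diag_embed(stds ** 2))
assert th.allclose(joint, mvn.log_prob(actions), atol=1e-5)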
@@ -304,8 +329,7 @@ class TRL_PG(OnPolicyAlgorithm):
                 # Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
                 trust_region_loss = self.projection.get_trust_region_loss(
-                    pol, p, proj_p)
-                # NOTE to future-self: policy has a different interface then in orig TRL-impl.
+                    p, proj_p)
 
                 trust_region_losses.append(trust_region_loss.item())
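
The exact form of get_trust_region_loss is defined by this repo's projection layers and is not part of this diff. As an illustration only: in the differentiable trust-region-layers formulation, the regression term penalizes the divergence between the unprojected policy and its projection, so a hypothetical stand-in for diagonal Gaussians could look like the sketch below (function name, argument order, and coefficient are assumptions, not this repo's API):

import torch as th

def kl_trust_region_loss(p: th.distributions.Normal,
                         proj_p: th.distributions.Normal,
                         coef: float = 1.0) -> th.Tensor:
    # Hypothetical stand-in: measure how far the unprojected policy p has
    # drifted from its projection proj_p, summed over action dimensions
    # and averaged over the batch.
    kl = th.distributions.kl_divergence(proj_p, p).sum(dim=-1)
    return coef * kl.mean()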
@@ -434,10 +458,7 @@ class TRL_PG(OnPolicyAlgorithm):
                 # Convert to pytorch tensor or to TensorDict
                 obs_tensor = obs_as_tensor(self._last_obs, self.device)
                 actions, values, log_probs = self.policy(obs_tensor)
-                dist = self.policy.get_distribution(obs_tensor)
-                # TODO: Enforce this requirement somwhere else...
-                assert isinstance(
-                    dist, th.distributions.Normal), 'TRL is only implemented for Policys in a continuous action-space that is gauss-parametarized!'
+                dist = self.policy.get_distribution(obs_tensor).distribution
                 mean, std = dist.mean, dist.stddev
             actions = actions.cpu().numpy()
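
Background on the .distribution access above: policy.get_distribution returns a stable-baselines3 Distribution wrapper (DiagGaussianDistribution for a Box action space with the default MlpPolicy), and its underlying torch object lives in the .distribution attribute, a th.distributions.Normal whose mean and stddev are exactly the per-dimension parameters stored in the rollout buffer. A small sketch, assuming stable-baselines3 (circa v1.5, as of this commit) and any continuous-action gym environment:

import torch as th
from stable_baselines3 import PPO
from stable_baselines3.common.utils import obs_as_tensor

model = PPO("MlpPolicy", "Pendulum-v1", verbose=0)
obs = model.env.reset()
obs_tensor = obs_as_tensor(obs, model.device)

with th.no_grad():
    dist = model.policy.get_distribution(obs_tensor).distribution

print(type(dist))              # <class 'torch.distributions.normal.Normal'>
print(dist.mean, dist.stddev)  # per-dimension Gaussian parameters, shape (n_envs, act_dim)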
@@ -495,3 +516,97 @@ class TRL_PG(OnPolicyAlgorithm):
         callback.on_rollout_end()
 
         return True
+
+
+class TRLRolloutBufferSamples(NamedTuple):
+    observations: th.Tensor
+    actions: th.Tensor
+    old_values: th.Tensor
+    old_log_prob: th.Tensor
+    advantages: th.Tensor
+    returns: th.Tensor
+    means: th.Tensor
+    stds: th.Tensor
+
+
+class TRLRolloutBuffer(RolloutBuffer):
+    def __init__(
+        self,
+        buffer_size: int,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
+        device: Union[th.device, str] = "cpu",
+        gae_lambda: float = 1,
+        gamma: float = 0.99,
+        n_envs: int = 1,
+    ):
+        super().__init__(buffer_size, observation_space, action_space,
+                         device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma)
+        self.means, self.stds = None, None
+
+    def reset(self) -> None:
+        self.means = np.zeros(
+            (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
+        self.stds = np.zeros(
+            # (self.buffer_size, self.n_envs) + self.action_space.shape + self.action_space.shape, dtype=np.float32)
+            (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
+        super().reset()
+    def add(
+        self,
+        obs: np.ndarray,
+        action: np.ndarray,
+        reward: np.ndarray,
+        episode_start: np.ndarray,
+        value: th.Tensor,
+        log_prob: th.Tensor,
+        mean: th.Tensor,
+        std: th.Tensor,
+    ) -> None:
+        """
+        :param obs: Observation
+        :param action: Action
+        :param reward:
+        :param episode_start: Start of episode signal.
+        :param value: estimated value of the current state
+            following the current policy.
+        :param log_prob: log probability of the action
+            following the current policy.
+        :param mean: mean of the action distribution of the current policy.
+        :param std: standard deviation of the action distribution of the current policy.
+        """
+        if len(log_prob.shape) == 0:
+            # Reshape 0-d tensor to avoid error
+            log_prob = log_prob.reshape(-1, 1)
+
+        # Reshape needed when using multiple envs with discrete observations
+        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
+        if isinstance(self.observation_space, spaces.Discrete):
+            obs = obs.reshape((self.n_envs,) + self.obs_shape)
+
+        self.observations[self.pos] = np.array(obs).copy()
+        self.actions[self.pos] = np.array(action).copy()
+        self.rewards[self.pos] = np.array(reward).copy()
+        self.episode_starts[self.pos] = np.array(episode_start).copy()
+        self.values[self.pos] = value.clone().cpu().numpy().flatten()
+        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
+        self.means[self.pos] = mean.clone().cpu().numpy()
+        self.stds[self.pos] = std.clone().cpu().numpy()
+        self.pos += 1
+        if self.pos == self.buffer_size:
+            self.full = True
+
+    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> TRLRolloutBufferSamples:
+        data = (
+            self.observations[batch_inds],
+            self.actions[batch_inds],
+            self.values[batch_inds].flatten(),
+            self.log_probs[batch_inds].flatten(),
+            self.advantages[batch_inds].flatten(),
+            self.returns[batch_inds].flatten(),
+            self.means[batch_inds].reshape((len(batch_inds), -1)),
+            self.stds[batch_inds].reshape((len(batch_inds), -1)),
+        )
+        return TRLRolloutBufferSamples(*tuple(map(self.to_torch, data)))
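
A quick smoke test of the buffer above (a sketch, assuming TRLRolloutBuffer and TRLRolloutBufferSamples as defined in this commit are importable; the module path is not shown in this diff). Note that the base RolloutBuffer.get() only swap-and-flattens the standard tensors (observations, actions, values, log_probs, advantages, returns), not means/stds, so the sketch sticks to n_envs=1, where the indexing in _get_samples lines up:

import numpy as np
import torch as th
from gym import spaces

obs_space = spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
act_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

buf = TRLRolloutBuffer(buffer_size=8, observation_space=obs_space,
                       action_space=act_space, gamma=0.99, gae_lambda=0.95, n_envs=1)
# Explicit reset: __init__ sets means/stds back to None after the base class's
# reset, so reset() must run before add() (the algorithm does this in collect_rollouts).
buf.reset()

for _ in range(8):
    buf.add(
        obs=np.random.randn(1, 4).astype(np.float32),
        action=np.random.randn(1, 2).astype(np.float32),
        reward=np.array([0.0], dtype=np.float32),
        episode_start=np.array([False]),
        value=th.zeros(1),
        log_prob=th.zeros(1),
        mean=th.zeros(1, 2),   # policy mean recorded for the projection
        std=th.ones(1, 2),     # policy std recorded for the projection
    )

buf.compute_returns_and_advantage(last_values=th.zeros(1), dones=np.array([False]))

for batch in buf.get(batch_size=4):
    # batch is a TRLRolloutBufferSamples; means/stds come back with shape (4, 2)
    print(batch.means.shape, batch.stds.shape)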