From b8488c531b108884ddbf167dbf5fdb1b4fe2564d Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Sat, 25 Jun 2022 21:47:39 +0200
Subject: [PATCH] Implemented TRLRolloutBuffer

---
 sb3_trl/trl_pg/trl_pg.py | 161 +++++++++++++++++++++++++++++++++------
 1 file changed, 138 insertions(+), 23 deletions(-)

diff --git a/sb3_trl/trl_pg/trl_pg.py b/sb3_trl/trl_pg/trl_pg.py
index 458c4e7..5a42136 100644
--- a/sb3_trl/trl_pg/trl_pg.py
+++ b/sb3_trl/trl_pg/trl_pg.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, Dict, Optional, Type, Union
+from typing import Any, Dict, Optional, Type, Union, NamedTuple
 
 import numpy as np
 import torch as th
@@ -14,9 +14,10 @@ from stable_baselines3.common.vec_env import VecEnv
 from stable_baselines3.common.buffers import RolloutBuffer
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.utils import obs_as_tensor
+from stable_baselines3.common.vec_env import VecNormalize
 
 from ..projections.base_projection_layer import BaseProjectionLayer
-from ..projections.frob_projection_layer import FrobeniusProjectionLayer
+# from ..projections.frob_projection_layer import FrobeniusProjectionLayer
 
 
 class TRL_PG(OnPolicyAlgorithm):
@@ -56,7 +57,8 @@ class TRL_PG(OnPolicyAlgorithm):
         Default: -1 (only sample at the beginning of the rollout)
     :param target_kl: Limit the KL divergence between updates,
         because the clipping is not enough to prevent large update
-        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
+        see issue #213
+        (cf https://github.com/hill-a/stable-baselines/issues/213)
         By default, there is no limit on the kl div.
     :param tensorboard_log: the log location for tensorboard (if None, no logging)
     :param create_eval_env: Whether to create a second environment that will be
@@ -103,7 +105,7 @@ class TRL_PG(OnPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
 
         # Different from PPO:
-        projection: BaseProjectionLayer = BaseProjectionLayer,
+        projection: BaseProjectionLayer = BaseProjectionLayer(),
         _init_setup_model: bool = True,
     ):
 
@@ -129,9 +131,9 @@ class TRL_PG(OnPolicyAlgorithm):
             _init_setup_model=False,
             supported_action_spaces=(
                 spaces.Box,
-                spaces.Discrete,
-                spaces.MultiDiscrete,
-                spaces.MultiBinary,
+                # spaces.Discrete,
+                # spaces.MultiDiscrete,
+                # spaces.MultiBinary,
             ),
         )
 
@@ -185,6 +187,17 @@ class TRL_PG(OnPolicyAlgorithm):
 
             self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
 
+        # Changed from PPO: We need a bigger RolloutBuffer
+        self.rollout_buffer = TRLRolloutBuffer(
+            self.n_steps,
+            self.observation_space,
+            self.action_space,
+            device=self.device,
+            gamma=self.gamma,
+            gae_lambda=self.gae_lambda,
+            n_envs=self.n_envs,
+        )
+
     def train(self) -> None:
         """
         Update policy using the currently gathered rollout buffer.
@@ -237,24 +250,36 @@ class TRL_PG(OnPolicyAlgorithm):
                 # log_prob == new_pogpacs (i think)
 
                 # src of evaluate_actions:
-                # features = self.extract_features(obs)
-                # latent_pi, latent_vf = self.mlp_extractor(features)
-                # distribution = self._get_action_dist_from_latent(latent_pi)
+                # pol = self.policy
+                # features = pol.extract_features(rollout_data.observations)
+                # latent_pi, latent_vf = pol.mlp_extractor(features)
+                # distribution = pol._get_action_dist_from_latent(latent_pi)
                 # log_prob = distribution.log_prob(actions)
-                # values = self.value_net(latent_vf)
+                # values = pol.value_net(latent_vf)
                 # return values, log_prob, distribution.entropy()
+                # entropy = distribution.entropy()
 
                 # here we go:
                 pol = self.policy
                 features = pol.extract_features(rollout_data.observations)
                 latent_pi, latent_vf = pol.mlp_extractor(features)
                 p = pol._get_action_dist_from_latent(latent_pi)
-                b_q = rollout_data.mean, rollout_data.std
-                proj_p = self.projection(pol, p, b_q, self._global_step)
-                log_prob = proj_p.log_prob(actions)
-                # or log_prob = pol.log_probability(proj_p, actions)
-                values = self.value_net(latent_vf)
-                entropy = proj_p.entropy()  # or not...
+                p_dist = p.distribution
+                # q_means = rollout_data.means
+                # if len(rollout_data.stds.shape) == 1:  # only diag
+                #     q_stds = th.diag(rollout_data.stds)
+                # else:
+                #     q_stds = rollout_data.stds
+                # q_dist = th.distributions.MultivariateNormal(
+                #     q_means, q_stds)
+                q_dist = th.distributions.Normal(
+                    rollout_data.means, rollout_data.stds)
+                proj_p = self.projection(p_dist, q_dist, self._global_steps)
+                log_prob = proj_p.log_prob(actions).sum(dim=1)
+                values = self.policy.value_net(latent_vf)
+                entropy = proj_p.entropy()
+
+                # log_prob = p.log_prob(actions)
 
                 values = values.flatten()
                 # Normalize advantage
@@ -304,8 +329,7 @@ class TRL_PG(OnPolicyAlgorithm):
 
                 # Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
                 trust_region_loss = self.projection.get_trust_region_loss(
-                    pol, p, proj_p)
-                # NOTE to future-self: policy has a different interface then in orig TRL-impl.
+                    p, proj_p)
 
                 trust_region_losses.append(trust_region_loss.item())
 
@@ -434,10 +458,7 @@ class TRL_PG(OnPolicyAlgorithm):
                 # Convert to pytorch tensor or to TensorDict
                 obs_tensor = obs_as_tensor(self._last_obs, self.device)
                 actions, values, log_probs = self.policy(obs_tensor)
-                dist = self.policy.get_distribution(obs_tensor)
-                # TODO: Enforce this requirement somwhere else...
-                assert isinstance(
-                    dist, th.distributions.Normal), 'TRL is only implemented for Policys in a continuous action-space that is gauss-parametarized!'
+                dist = self.policy.get_distribution(obs_tensor).distribution
                 mean, std = dist.mean, dist.stddev
 
             actions = actions.cpu().numpy()
@@ -495,3 +516,97 @@ class TRL_PG(OnPolicyAlgorithm):
         callback.on_rollout_end()
 
         return True
+
+
+class TRLRolloutBufferSamples(NamedTuple):
+    observations: th.Tensor
+    actions: th.Tensor
+    old_values: th.Tensor
+    old_log_prob: th.Tensor
+    advantages: th.Tensor
+    returns: th.Tensor
+    means: th.Tensor
+    stds: th.Tensor
+
+
+class TRLRolloutBuffer(RolloutBuffer):
+    def __init__(
+        self,
+        buffer_size: int,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
+        device: Union[th.device, str] = "cpu",
+        gae_lambda: float = 1,
+        gamma: float = 0.99,
+        n_envs: int = 1,
+    ):
+
+        self.means, self.stds = None, None
+        super().__init__(buffer_size, observation_space, action_space,
+                         device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma)
+
+    def reset(self) -> None:
+        self.means = np.zeros(
+            (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
+        self.stds = np.zeros(
+            # (self.buffer_size, self.n_envs) + self.action_space.shape + self.action_space.shape, dtype=np.float32)
+            (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
+        super().reset()
+
+    def add(
+        self,
+        obs: np.ndarray,
+        action: np.ndarray,
+        reward: np.ndarray,
+        episode_start: np.ndarray,
+        value: th.Tensor,
+        log_prob: th.Tensor,
+        mean: th.Tensor,
+        std: th.Tensor,
+    ) -> None:
+        """
+        :param obs: Observation
+        :param action: Action
+        :param reward:
+        :param episode_start: Start of episode signal.
+        :param value: estimated value of the current state
+            following the current policy.
+        :param log_prob: log probability of the action
+            following the current policy.
+        :param mean: mean of the current action distribution
+        :param std: standard deviation of the current action distribution
+        """
+
+        if len(log_prob.shape) == 0:
+            # Reshape 0-d tensor to avoid error
+            log_prob = log_prob.reshape(-1, 1)
+
+        # Reshape needed when using multiple envs with discrete observations
+        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
+        if isinstance(self.observation_space, spaces.Discrete):
+            obs = obs.reshape((self.n_envs,) + self.obs_shape)
+
+        self.observations[self.pos] = np.array(obs).copy()
+        self.actions[self.pos] = np.array(action).copy()
+        self.rewards[self.pos] = np.array(reward).copy()
+        self.episode_starts[self.pos] = np.array(episode_start).copy()
+        self.values[self.pos] = value.clone().cpu().numpy().flatten()
+        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
+        self.means[self.pos] = mean.clone().cpu().numpy()
+        self.stds[self.pos] = std.clone().cpu().numpy()
+        self.pos += 1
+        if self.pos == self.buffer_size:
+            self.full = True
+
+    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> TRLRolloutBufferSamples:
+        data = (
+            self.observations[batch_inds],
+            self.actions[batch_inds],
+            self.values[batch_inds].flatten(),
+            self.log_probs[batch_inds].flatten(),
+            self.advantages[batch_inds].flatten(),
+            self.returns[batch_inds].flatten(),
+            self.means[batch_inds].reshape((len(batch_inds), -1)),
+            self.stds[batch_inds].reshape((len(batch_inds), -1)),
+        )
+        return TRLRolloutBufferSamples(*tuple(map(self.to_torch, data)))
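
Usage note (editor's addition, not part of the patch): a minimal sketch of how the new
TRLRolloutBuffer can be exercised on its own. The import path sb3_trl.trl_pg.trl_pg and
the gym-based spaces are assumptions based on the file path and the stable-baselines3
version targeted here; the shapes follow the allocation in reset(), where means and stds
are stored per step and per env with the action shape appended.

    import numpy as np
    import torch as th
    from gym import spaces

    from sb3_trl.trl_pg.trl_pg import TRLRolloutBuffer  # assumed import path

    obs_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
    act_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

    # Buffer for 4 steps of a single env; reset() allocates means/stds with shape (4, 1, 2)
    buf = TRLRolloutBuffer(4, obs_space, act_space, n_envs=1)
    buf.reset()  # collect_rollouts() also resets before every rollout

    buf.add(
        obs=np.zeros((1, 3), dtype=np.float32),
        action=np.zeros((1, 2), dtype=np.float32),
        reward=np.array([0.0], dtype=np.float32),
        episode_start=np.array([False]),
        value=th.zeros(1),
        log_prob=th.zeros(1),
        mean=th.zeros(1, 2),  # old-policy mean, later read back as rollout_data.means
        std=th.ones(1, 2),    # old-policy std, later read back as rollout_data.stds
    )

The stored means/stds are what train() wraps into the behaviour distribution q via
th.distributions.Normal(rollout_data.means, rollout_data.stds) before projecting the
current policy distribution p onto the trust region.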