diff --git a/sb3_trl/trl_pg/trl_pg.py b/sb3_trl/trl_pg/trl_pg.py
index 45badff..d286896 100644
--- a/sb3_trl/trl_pg/trl_pg.py
+++ b/sb3_trl/trl_pg/trl_pg.py
@@ -17,7 +17,9 @@ from stable_baselines3.common.utils import obs_as_tensor
 from stable_baselines3.common.vec_env import VecNormalize
 
 from ..projections.base_projection_layer import BaseProjectionLayer
-# from ..projections.frob_projection_layer import FrobeniusProjectionLayer
+from ..projections.frob_projection_layer import FrobeniusProjectionLayer
+
+from ..misc.rollout_buffer import GaussianRolloutBuffer, GaussianRolloutBufferSamples
 
 
 class TRL_PG(OnPolicyAlgorithm):
@@ -105,7 +107,8 @@ class TRL_PG(OnPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
 
         # Different from PPO:
-        projection: BaseProjectionLayer = BaseProjectionLayer(),
+        projection: BaseProjectionLayer = FrobeniusProjectionLayer(),
+        #projection: BaseProjectionLayer = BaseProjectionLayer(),
 
         _init_setup_model: bool = True,
     ):
@@ -188,7 +191,7 @@ class TRL_PG(OnPolicyAlgorithm):
             self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
 
         # Changed from PPO: We need a bigger RolloutBuffer
-        self.rollout_buffer = TRLRolloutBuffer(
+        self.rollout_buffer = GaussianRolloutBuffer(
             self.n_steps,
             self.observation_space,
             self.action_space,
@@ -513,97 +516,3 @@ class TRL_PG(OnPolicyAlgorithm):
         callback.on_rollout_end()
 
         return True
-
-
-class TRLRolloutBufferSamples(NamedTuple):
-    observations: th.Tensor
-    actions: th.Tensor
-    old_values: th.Tensor
-    old_log_prob: th.Tensor
-    advantages: th.Tensor
-    returns: th.Tensor
-    means: th.Tensor
-    stds: th.Tensor
-
-
-class TRLRolloutBuffer(RolloutBuffer):
-    def __init__(
-        self,
-        buffer_size: int,
-        observation_space: spaces.Space,
-        action_space: spaces.Space,
-        device: Union[th.device, str] = "cpu",
-        gae_lambda: float = 1,
-        gamma: float = 0.99,
-        n_envs: int = 1,
-    ):
-
-        super().__init__(buffer_size, observation_space, action_space,
-                         device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma)
-        self.means, self.stds = None, None
-
-    def reset(self) -> None:
-        self.means = np.zeros(
-            (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
-        self.stds = np.zeros(
-            # (self.buffer_size, self.n_envs) + self.action_space.shape + self.action_space.shape, dtype=np.float32)
-            (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
-        super().reset()
-
-    def add(
-        self,
-        obs: np.ndarray,
-        action: np.ndarray,
-        reward: np.ndarray,
-        episode_start: np.ndarray,
-        value: th.Tensor,
-        log_prob: th.Tensor,
-        mean: th.Tensor,
-        std: th.Tensor,
-    ) -> None:
-        """
-        :param obs: Observation
-        :param action: Action
-        :param reward:
-        :param episode_start: Start of episode signal.
-        :param value: estimated value of the current state
-            following the current policy.
-        :param log_prob: log probability of the action
-            following the current policy.
-        :param mean: Foo
-        :param std: Bar
-        """
-
-        if len(log_prob.shape) == 0:
-            # Reshape 0-d tensor to avoid error
-            log_prob = log_prob.reshape(-1, 1)
-
-        # Reshape needed when using multiple envs with discrete observations
-        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
-        if isinstance(self.observation_space, spaces.Discrete):
-            obs = obs.reshape((self.n_envs,) + self.obs_shape)
-
-        self.observations[self.pos] = np.array(obs).copy()
-        self.actions[self.pos] = np.array(action).copy()
-        self.rewards[self.pos] = np.array(reward).copy()
-        self.episode_starts[self.pos] = np.array(episode_start).copy()
-        self.values[self.pos] = value.clone().cpu().numpy().flatten()
-        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
-        self.means[self.pos] = mean.clone().cpu().numpy()
-        self.stds[self.pos] = std.clone().cpu().numpy()
-        self.pos += 1
-        if self.pos == self.buffer_size:
-            self.full = True
-
-    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> TRLRolloutBufferSamples:
-        data = (
-            self.observations[batch_inds],
-            self.actions[batch_inds],
-            self.values[batch_inds].flatten(),
-            self.log_probs[batch_inds].flatten(),
-            self.advantages[batch_inds].flatten(),
-            self.returns[batch_inds].flatten(),
-            self.means[batch_inds].reshape((len(batch_inds), -1)),
-            self.stds[batch_inds].reshape((len(batch_inds), -1)),
-        )
-        return TRLRolloutBufferSamples(*tuple(map(self.to_torch, data)))
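
For orientation, a minimal usage sketch of the relocated buffer. This is an assumption-laden illustration, not code from this patch: it assumes GaussianRolloutBuffer in sb3_trl/misc/rollout_buffer.py keeps the constructor and add() signature of the removed TRLRolloutBuffer, and that gym spaces are used as in the removed code. The point of the renamed buffer is unchanged: besides the usual PPO rollout data, it stores the Gaussian policy's per-step mean and std so the projection layer can reconstruct the old policy distribution during training.

import numpy as np
import torch as th
from gym import spaces  # assumption: project uses gym, as the removed buffer code does

from sb3_trl.misc.rollout_buffer import GaussianRolloutBuffer

# Toy continuous spaces, single environment.
obs_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
act_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

# Assumed to mirror the removed TRLRolloutBuffer constructor.
buffer = GaussianRolloutBuffer(
    buffer_size=8,
    observation_space=obs_space,
    action_space=act_space,
    device="cpu",
    gae_lambda=0.95,
    gamma=0.99,
    n_envs=1,
)
buffer.reset()  # allocate the means/stds arrays before adding transitions

# One rollout step: the Gaussian policy's mean and std are stored alongside
# value and log_prob (assumed add() signature, matching the removed class).
buffer.add(
    obs_space.sample()[None],       # (n_envs, obs_dim)
    act_space.sample()[None],       # (n_envs, act_dim)
    np.zeros(1, dtype=np.float32),  # reward
    np.ones(1, dtype=np.float32),   # episode_start
    value=th.zeros(1),
    log_prob=th.zeros(1),
    mean=th.zeros(1, 2),
    std=th.ones(1, 2),
)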