diff --git a/metastable_baselines2/common/buffers.py b/metastable_baselines2/common/buffers.py
index 500167f..94e807c 100644
--- a/metastable_baselines2/common/buffers.py
+++ b/metastable_baselines2/common/buffers.py
@@ -29,7 +29,7 @@ class BetterRolloutBufferSamples(NamedTuple):
     old_values: th.Tensor
     old_log_prob: th.Tensor
     mean: th.Tensor
-    cov_Decomp: th.Tensor
+    cov_decomp: th.Tensor
     advantages: th.Tensor
     returns: th.Tensor
 
@@ -40,7 +40,7 @@ class BetterDictRolloutBufferSamples(NamedTuple):
     old_values: th.Tensor
     old_log_prob: th.Tensor
     mean: th.Tensor
-    cov_Decomp: th.Tensor
+    cov_decomp: th.Tensor
     advantages: th.Tensor
     returns: th.Tensor
 
@@ -227,8 +227,8 @@ class BetterRolloutBuffer(RolloutBuffer):
             self.actions[batch_inds],
             self.values[batch_inds].flatten(),
             self.log_probs[batch_inds].flatten(),
-            self.means[batch_inds].flatten(),
-            self.cov_decomps[batch_inds].flatten(),
+            np.squeeze(self.means[batch_inds], axis=1),
+            np.squeeze(self.cov_decomps[batch_inds], axis=1),
             self.advantages[batch_inds].flatten(),
             self.returns[batch_inds].flatten(),
         )
diff --git a/metastable_baselines2/trpl/trpl.py b/metastable_baselines2/trpl/trpl.py
index 539e38e..edd3be8 100644
--- a/metastable_baselines2/trpl/trpl.py
+++ b/metastable_baselines2/trpl/trpl.py
@@ -39,28 +39,31 @@ class TRPL(BetterOnPolicyAlgorithm):
     :param n_epochs: Number of epoch when optimizing the surrogate loss
     :param gamma: Discount factor
     :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
-    :param clip_range: Clipping parameter, it can be a function of the current progress
-        remaining (from 1 to 0).
-    :param clip_range_vf: Clipping parameter for the value function,
-        it can be a function of the current progress remaining (from 1 to 0).
+    :param clip_range: Should not be used for normal TRPL usage; only provided to bridge the gap to PPO.
+        Clipping parameter, it can be a function of the current progress remaining (from 1 to 0).
+        Setting it to None results in no clipping (default).
+    :param clip_range_vf: Should not be used for normal TRPL usage; only provided to bridge the gap to PPO.
+        Clipping parameter for the value function, it can be a function of the current progress remaining (from 1 to 0).
         This is a parameter specific to the OpenAI implementation. If None is passed (default),
         no clipping will be done on the value function.
         IMPORTANT: this clipping depends on the reward scaling.
-    :param normalize_advantage: Whether to normalize or not the advantage
+    :param normalize_advantage: Whether to normalize the advantage. Normalization is usually a good idea,
+        but TRPL often works better without it. (Default: False)
     :param ent_coef: Entropy coefficient for the loss calculation
     :param vf_coef: Value function coefficient for the loss calculation
-    :param max_grad_norm: The maximum value for the gradient clipping
+    :param max_grad_norm: Should not be used for normal TRPL usage; only provided to bridge the gap to PPO.
+        The maximum value for the gradient clipping. Setting it to None results in no gradient clipping (default).
     :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
         instead of action noise exploration (default: False)
     :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
         Default: -1 (only sample at the beginning of the rollout)
-    :param use_pca: Wether to use Prior Conditioned Annealing
+    :param use_pca: Whether to use Prior Conditioned Annealing.
     :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
     :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation
-    :param target_kl: Limit the KL divergence between updates,
-        because the clipping is not enough to prevent large update
+    :param target_kl: Not part of the reference TRPL implementation, but ported over from SB3's PPO.
+        Defaults to None: no limit.
+        Limit the KL divergence between updates, because the clipping is not enough to prevent large updates,
         see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
-        By default, there is no limit on the kl div.
     :param stats_window_size: Window size for the rollout logging, specifying the number of
         episodes to average the reported success rate, mean episode length, and mean reward over
     :param tensorboard_log: the log location for tensorboard (if None, no logging)
@@ -89,7 +92,7 @@ class TRPL(BetterOnPolicyAlgorithm):
         gae_lambda: float = 0.95,
         clip_range: Union[float, Schedule, None] = None,
         clip_range_vf: Union[None, float, Schedule] = None,
-        normalize_advantage: bool = True,
+        normalize_advantage: bool = False,
         ent_coef: float = 0.0,
         vf_coef: float = 0.5,
         max_grad_norm: Union[float, None] = None,
@@ -97,7 +100,8 @@ class TRPL(BetterOnPolicyAlgorithm):
         sde_sample_freq: int = -1,
         use_pca: bool = False,
         pca_is: bool = False,
-        projection: Union[BaseProjectionLayer, str] = BaseProjectionLayer(),
+        projection_class: Union[BaseProjectionLayer, str] = BaseProjectionLayer,
+        projection_kwargs: Optional[Dict[str, Any]] = None,
         rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
         rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
         target_kl: Optional[float] = None,
@@ -109,6 +113,11 @@ class TRPL(BetterOnPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
     ):
+        self.projection_class = castProjection(projection_class)
+        self.projection_kwargs = projection_kwargs or {}
+        self.projection = self.projection_class(**self.projection_kwargs)
+        policy_kwargs['policy_projection'] = self.projection
+
         super().__init__(
             policy,
             env,
@@ -176,7 +185,6 @@ class TRPL(BetterOnPolicyAlgorithm):
         clip_range_vf = None
         self.clip_range_vf = clip_range_vf
         self.normalize_advantage = normalize_advantage
-        self.projection = castProjection(projection)
         self.target_kl = target_kl
 
         if _init_setup_model:
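
Note on the buffers.py hunk: `mean` and `cov_decomp` are per-sample distribution parameters rather than scalars, so `.flatten()` would collapse them into one 1-D vector and lose the per-sample shape the projection layer needs, while `np.squeeze(..., axis=1)` only drops the singleton axis. A minimal shape sketch, assuming the stored arrays carry a singleton (e.g. env) axis after indexing with `batch_inds`; the concrete dimensions are illustrative, not taken from the code:

```python
import numpy as np

# Assumed shapes for illustration only:
#   means       -> (batch, 1, action_dim)
#   cov_decomps -> (batch, 1, action_dim, action_dim)
batch, action_dim = 4, 3
means = np.random.randn(batch, 1, action_dim)
cov_decomps = np.random.randn(batch, 1, action_dim, action_dim)

# .flatten() collapses everything into a single 1-D vector,
# destroying the per-sample structure:
print(means.flatten().shape)                  # (12,)
print(cov_decomps.flatten().shape)            # (36,)

# np.squeeze(..., axis=1) only removes the singleton axis and keeps
# the distribution parameters intact for each sample:
print(np.squeeze(means, axis=1).shape)        # (4, 3)
print(np.squeeze(cov_decomps, axis=1).shape)  # (4, 3, 3)
```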
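
Note on the trpl.py hunk: the constructor now takes `projection_class` (a `BaseProjectionLayer` subclass or a string resolved via `castProjection`) plus `projection_kwargs`, builds the projection itself, and injects it into `policy_kwargs['policy_projection']`, instead of accepting a pre-built `projection` instance. A hypothetical usage sketch under those assumptions; the import path follows the file path in the diff, and the environment id and policy alias are illustrative, not confirmed by the source:

```python
import gymnasium as gym

from metastable_baselines2.trpl.trpl import TRPL  # module path from the diff; package layout assumed

model = TRPL(
    "MlpPolicy",                  # assumed SB3-style policy alias
    gym.make("Pendulum-v1"),      # illustrative continuous-control env
    # projection_class defaults to BaseProjectionLayer; a class or a string
    # name (resolved via castProjection) can be passed instead.
    projection_kwargs={},         # forwarded to the projection layer's constructor
    normalize_advantage=False,    # matches the new TRPL default
)
model.learn(total_timesteps=50_000)  # inherited SB3-style training loop
```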