Fix: Buffer supplying weird shapes for mean and cov

Dominik Moritz Roth 2024-01-26 12:40:24 +01:00
parent 53505a80ad
commit e788e9f998
2 changed files with 25 additions and 17 deletions

@@ -29,7 +29,7 @@ class BetterRolloutBufferSamples(NamedTuple):
old_values: th.Tensor
old_log_prob: th.Tensor
mean: th.Tensor
cov_Decomp: th.Tensor
cov_decomp: th.Tensor
advantages: th.Tensor
returns: th.Tensor
@@ -40,7 +40,7 @@ class BetterDictRolloutBufferSamples(NamedTuple):
old_values: th.Tensor
old_log_prob: th.Tensor
mean: th.Tensor
cov_Decomp: th.Tensor
cov_decomp: th.Tensor
advantages: th.Tensor
returns: th.Tensor
@@ -227,8 +227,8 @@ class BetterRolloutBuffer(RolloutBuffer):
self.actions[batch_inds],
self.values[batch_inds].flatten(),
self.log_probs[batch_inds].flatten(),
self.means[batch_inds].flatten(),
self.cov_decomps[batch_inds].flatten(),
np.squeeze(self.means[batch_inds], axis=1),
np.squeeze(self.cov_decomps[batch_inds], axis=1),
self.advantages[batch_inds].flatten(),
self.returns[batch_inds].flatten(),
)

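To illustrate the shape fix above: the Gaussian mean and covariance decomposition carry extra per-sample dimensions, so flattening them collapses the structure the policy update needs, while squeezing only the singleton axis keeps it. A minimal sketch with assumed shapes of (buffer_size, 1, action_dim) for the means and (buffer_size, 1, action_dim, action_dim) for the decompositions; the actual buffer layout may differ.

import numpy as np

batch_inds = np.array([0, 2, 5])
means = np.zeros((8, 1, 3))           # assumed (buffer_size, 1, action_dim)
cov_decomps = np.zeros((8, 1, 3, 3))  # assumed (buffer_size, 1, action_dim, action_dim)

# Old behaviour: flatten() collapses everything to 1-D, losing the
# per-sample action dimension and the matrix structure of the decomposition.
print(means[batch_inds].flatten().shape)        # (9,)
print(cov_decomps[batch_inds].flatten().shape)  # (27,)

# New behaviour: squeeze only the singleton axis and keep the rest intact.
print(np.squeeze(means[batch_inds], axis=1).shape)        # (3, 3)
print(np.squeeze(cov_decomps[batch_inds], axis=1).shape)  # (3, 3, 3)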

@@ -39,28 +39,31 @@ class TRPL(BetterOnPolicyAlgorithm):
:param n_epochs: Number of epochs when optimizing the surrogate loss
:param gamma: Discount factor
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
:param clip_range: Clipping parameter, it can be a function of the current progress
remaining (from 1 to 0).
:param clip_range_vf: Clipping parameter for the value function,
it can be a function of the current progress remaining (from 1 to 0).
:param clip_range: Should not be used for normal TRPL usage; it is only here to bridge the gap to PPO.
Clipping parameter; it can be a function of the current progress remaining (from 1 to 0).
Setting it to None results in no clipping (default).
:param clip_range_vf: Should not be used for normal TRPL usage; it is only here to bridge the gap to PPO.
Clipping parameter for the value function; it can be a function of the current progress remaining (from 1 to 0).
This is a parameter specific to the OpenAI implementation. If None is passed (default),
no clipping will be done on the value function.
IMPORTANT: this clipping depends on the reward scaling.
:param normalize_advantage: Whether to normalize or not the advantage
:param normalize_advantage: Normally a good idea, but TRPL often works better without advantage normalization.
Whether or not to normalize the advantage (default: False).
:param ent_coef: Entropy coefficient for the loss calculation
:param vf_coef: Value function coefficient for the loss calculation
:param max_grad_norm: The maximum value for the gradient clipping
:param max_grad_norm: Should not be used for normal TRPL usage; it is only here to bridge the gap to PPO.
The maximum value for gradient clipping. Setting it to None results in no gradient clipping (default).
:param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param use_pca: Wether to use Prior Conditioned Annealing
:param use_pca: Whether to use Prior Conditioned Annealing.
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation
:param target_kl: Limit the KL divergence between updates,
because the clipping is not enough to prevent large update
:param target_kl: Not part of the reference implementation of TRPL, but ported over from sb3's PPO.
Defaults to None: no limit.
Limit the KL divergence between updates, because the clipping is not enough to prevent large updates;
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
By default, there is no limit on the kl div.
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
@@ -89,7 +92,7 @@ class TRPL(BetterOnPolicyAlgorithm):
gae_lambda: float = 0.95,
clip_range: Union[float, Schedule, None] = None,
clip_range_vf: Union[None, float, Schedule] = None,
normalize_advantage: bool = True,
normalize_advantage: bool = False,
ent_coef: float = 0.0,
vf_coef: float = 0.5,
max_grad_norm: Union[float, None] = None,
@@ -97,7 +100,8 @@
sde_sample_freq: int = -1,
use_pca: bool = False,
pca_is: bool = False,
projection: Union[BaseProjectionLayer, str] = BaseProjectionLayer(),
projection_class: Union[BaseProjectionLayer, str] = BaseProjectionLayer,
projection_kwargs: Optional[Dict[str, Any]] = None,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
target_kl: Optional[float] = None,
@@ -109,6 +113,11 @@
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
# Resolve the projection class via castProjection, instantiate it, and hand the instance to the policy.
self.projection_class = castProjection(projection_class)
self.projection_kwargs = projection_kwargs or {}  # tolerate the None default
self.projection = self.projection_class(**self.projection_kwargs)
policy_kwargs['policy_projection'] = self.projection
super().__init__(
policy,
env,
@@ -176,7 +185,6 @@
clip_range_vf = None
self.clip_range_vf = clip_range_vf
self.normalize_advantage = normalize_advantage
self.projection = castProjection(projection)
self.target_kl = target_kl
if _init_setup_model:
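For context on the constructor change above: the projection is now specified as a class (or string alias) plus a kwargs dict and instantiated inside __init__, rather than being passed in as a ready-made BaseProjectionLayer instance. Below is a minimal, self-contained sketch of that class-plus-kwargs pattern; BaseProjection, _REGISTRY, resolve_projection, and build_projection are illustrative stand-ins, not the project's actual API, and castProjection is assumed to play the resolver role.

from typing import Any, Dict, Optional, Type, Union

class BaseProjection:
    # Illustrative stand-in for BaseProjectionLayer.
    def __init__(self, mean_bound: float = 0.03):
        self.mean_bound = mean_bound

# Hypothetical name -> class lookup for resolving string aliases.
_REGISTRY: Dict[str, Type[BaseProjection]] = {"base": BaseProjection}

def resolve_projection(spec: Union[Type[BaseProjection], str]) -> Type[BaseProjection]:
    # Accept either a class or a registered string alias.
    return _REGISTRY[spec] if isinstance(spec, str) else spec

def build_projection(projection_class: Union[Type[BaseProjection], str],
                     projection_kwargs: Optional[Dict[str, Any]] = None) -> BaseProjection:
    cls = resolve_projection(projection_class)
    return cls(**(projection_kwargs or {}))  # None kwargs -> no-arg construction

proj = build_projection("base", {"mean_bound": 0.1})
print(proj.mean_bound)  # 0.1

One benefit of this pattern is that the projection's hyperparameters stay visible as plain constructor arguments instead of being hidden inside a pre-built object.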