Fix: Buffer supplying weird shapes for mean and cov
parent 53505a80ad
commit e788e9f998
@@ -29,7 +29,7 @@ class BetterRolloutBufferSamples(NamedTuple):
     old_values: th.Tensor
     old_log_prob: th.Tensor
     mean: th.Tensor
-    cov_Decomp: th.Tensor
+    cov_decomp: th.Tensor
    advantages: th.Tensor
     returns: th.Tensor
 
@@ -40,7 +40,7 @@ class BetterDictRolloutBufferSamples(NamedTuple):
     old_values: th.Tensor
     old_log_prob: th.Tensor
     mean: th.Tensor
-    cov_Decomp: th.Tensor
+    cov_decomp: th.Tensor
     advantages: th.Tensor
     returns: th.Tensor
 
@@ -227,8 +227,8 @@ class BetterRolloutBuffer(RolloutBuffer):
             self.actions[batch_inds],
             self.values[batch_inds].flatten(),
             self.log_probs[batch_inds].flatten(),
-            self.means[batch_inds].flatten(),
-            self.cov_decomps[batch_inds].flatten(),
+            np.squeeze(self.means[batch_inds], axis=1),
+            np.squeeze(self.cov_decomps[batch_inds], axis=1),
             self.advantages[batch_inds].flatten(),
             self.returns[batch_inds].flatten(),
         )
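The hunk above is the core of the fix. A minimal numpy sketch of the shape difference, assuming the buffer stores each saved mean with an extra singleton axis, e.g. (batch, 1, action_dim), and each covariance decomposition as (batch, 1, action_dim, action_dim); the stored shapes are an assumption, only the flatten-vs-squeeze behaviour is taken from the diff:

import numpy as np

# Assumed stored shapes: one singleton axis per saved distribution parameter.
batch, d = 64, 6
means = np.zeros((batch, 1, d))           # per-step Gaussian means
cov_decomps = np.zeros((batch, 1, d, d))  # per-step covariance decompositions

# Old behaviour: flatten() collapses everything to 1-D and loses the structure
# that the policy update needs.
print(means.flatten().shape)                  # (384,)
print(cov_decomps.flatten().shape)            # (2304,)

# New behaviour: squeeze only the singleton axis, keeping (batch, d) / (batch, d, d).
print(np.squeeze(means, axis=1).shape)        # (64, 6)
print(np.squeeze(cov_decomps, axis=1).shape)  # (64, 6, 6)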
@@ -39,28 +39,31 @@ class TRPL(BetterOnPolicyAlgorithm):
     :param n_epochs: Number of epoch when optimizing the surrogate loss
     :param gamma: Discount factor
     :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
-    :param clip_range: Clipping parameter, it can be a function of the current progress
-        remaining (from 1 to 0).
-    :param clip_range_vf: Clipping parameter for the value function,
-        it can be a function of the current progress remaining (from 1 to 0).
+    :param clip_range: Should not be used for normal TRPL usage. Is only here to bridge the gap to PPO.
+        Clipping parameter, it can be a function of the current progress remaining (from 1 to 0).
+        Setting it to None will result in no clipping (default)
+    :param clip_range_vf: Should not be used for normal TRPL usage. Is only here to bridge the gap to PPO.
+        Clipping parameter for the value function, it can be a function of the current progress remaining (from 1 to 0).
         This is a parameter specific to the OpenAI implementation. If None is passed (default),
         no clipping will be done on the value function.
         IMPORTANT: this clipping depends on the reward scaling.
-    :param normalize_advantage: Whether to normalize or not the advantage
+    :param normalize_advantage: Normally a good idea, but TRPL often works better without advantage normalization.
+        Whether to normalize or not the advantage. (Default: False)
     :param ent_coef: Entropy coefficient for the loss calculation
     :param vf_coef: Value function coefficient for the loss calculation
-    :param max_grad_norm: The maximum value for the gradient clipping
+    :param max_grad_norm: Should not be used for normal TRPL usage. Is only here to bridge the gap to PPO.
+        The maximum value for the gradient clipping. Setting it to None will result in no gradient clipping (default)
     :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
         instead of action noise exploration (default: False)
     :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
         Default: -1 (only sample at the beginning of the rollout)
-    :param use_pca: Wether to use Prior Conditioned Annealing
+    :param use_pca: Whether to use Prior Conditioned Annealing.
     :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
     :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation
-    :param target_kl: Limit the KL divergence between updates,
-        because the clipping is not enough to prevent large update
+    :param target_kl: Not part of the reference implementation of TRPL, but we still ported it over from sb3's PPO.
+        Defaults to None: no limit.
+        Limit the KL divergence between updates, because the clipping is not enough to prevent large update
         see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
-        By default, there is no limit on the kl div.
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over
     :param tensorboard_log: the log location for tensorboard (if None, no logging)
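Since target_kl is documented above as ported from sb3's PPO, here is a minimal sketch of the early-stopping check that flag drives there; whether TRPL's train() reuses this exact logic is an assumption based on that PPO heritage:

# sb3-PPO-style early stopping on KL divergence (assumed to carry over here).
target_kl = 0.03      # example threshold; None disables the check entirely
approx_kl_div = 0.05  # estimated KL between old and new policy for this epoch

if target_kl is not None and approx_kl_div > 1.5 * target_kl:
    print(f"Early stopping: KL {approx_kl_div:.3f} exceeded 1.5 * {target_kl}")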
@@ -89,7 +92,7 @@ class TRPL(BetterOnPolicyAlgorithm):
         gae_lambda: float = 0.95,
         clip_range: Union[float, Schedule, None] = None,
         clip_range_vf: Union[None, float, Schedule] = None,
-        normalize_advantage: bool = True,
+        normalize_advantage: bool = False,
         ent_coef: float = 0.0,
         vf_coef: float = 0.5,
         max_grad_norm: Union[float, None] = None,
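For context on the default flip above, a short sketch of what the normalize_advantage flag typically toggles in sb3-style training loops (per-minibatch standardization of advantages); that TRPL's train() does exactly this is an assumption based on its PPO lineage:

import numpy as np

def maybe_normalize_advantages(advantages: np.ndarray, normalize: bool) -> np.ndarray:
    """Standardize advantages per minibatch, as sb3's PPO does when the flag is set."""
    if normalize and len(advantages) > 1:
        return (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return advantages

adv = np.array([0.5, -1.2, 3.0, 0.1])
print(maybe_normalize_advantages(adv, normalize=False))  # unchanged: the new TRPL default
print(maybe_normalize_advantages(adv, normalize=True))   # zero-mean, unit-std minibatch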
@@ -97,7 +100,8 @@ class TRPL(BetterOnPolicyAlgorithm):
         sde_sample_freq: int = -1,
         use_pca: bool = False,
         pca_is: bool = False,
-        projection: Union[BaseProjectionLayer, str] = BaseProjectionLayer(),
+        projection_class: Union[BaseProjectionLayer, str] = BaseProjectionLayer,
+        projection_kwargs: Optional[Dict[str, Any]] = None,
         rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
         rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
         target_kl: Optional[float] = None,
@@ -109,6 +113,11 @@ class TRPL(BetterOnPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
     ):
+        self.projection_class = castProjection(projection_class)
+        self.projection_kwargs = projection_kwargs
+        self.projection = self.projection_class(**self.projection_kwargs)
+        policy_kwargs['policy_projection'] = self.projection
+
         super().__init__(
             policy,
             env,
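The hunk above switches from accepting a ready-made projection object to a class plus kwargs that are instantiated inside __init__ and handed to the policy. A self-contained toy sketch of that pattern follows; DummyProjection, DummyAlgo, and the mean_bound/cov_bound kwargs are made up for illustration and are not the project's real classes:

from typing import Any, Dict, Optional, Type


class DummyProjection:
    """Stand-in for a projection layer; mean_bound/cov_bound are hypothetical kwargs."""
    def __init__(self, mean_bound: float = 0.03, cov_bound: float = 0.001):
        self.mean_bound = mean_bound
        self.cov_bound = cov_bound


class DummyAlgo:
    """Mirrors the new pattern: build the projection from class + kwargs in __init__."""
    def __init__(
        self,
        projection_class: Type[DummyProjection] = DummyProjection,
        projection_kwargs: Optional[Dict[str, Any]] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self.projection_class = projection_class
        self.projection_kwargs = projection_kwargs or {}
        self.projection = self.projection_class(**self.projection_kwargs)
        # The policy receives the very same projection instance via policy_kwargs.
        self.policy_kwargs = dict(policy_kwargs or {})
        self.policy_kwargs["policy_projection"] = self.projection


algo = DummyAlgo(projection_kwargs={"mean_bound": 0.05})
print(algo.policy_kwargs["policy_projection"].mean_bound)  # 0.05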
@@ -176,7 +185,6 @@ class TRPL(BetterOnPolicyAlgorithm):
             clip_range_vf = None
         self.clip_range_vf = clip_range_vf
         self.normalize_advantage = normalize_advantage
-        self.projection = castProjection(projection)
         self.target_kl = target_kl
 
         if _init_setup_model: