Fix: Buffer supplying weird shapes for mean and cov
commit e788e9f998
parent 53505a80ad
@@ -29,7 +29,7 @@ class BetterRolloutBufferSamples(NamedTuple):
     old_values: th.Tensor
     old_log_prob: th.Tensor
     mean: th.Tensor
-    cov_Decomp: th.Tensor
+    cov_decomp: th.Tensor
     advantages: th.Tensor
     returns: th.Tensor

@@ -40,7 +40,7 @@ class BetterDictRolloutBufferSamples(NamedTuple):
     old_values: th.Tensor
     old_log_prob: th.Tensor
     mean: th.Tensor
-    cov_Decomp: th.Tensor
+    cov_decomp: th.Tensor
     advantages: th.Tensor
     returns: th.Tensor

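For orientation, here is how the sample tuple reads after the rename; a sketch reconstructed from the fields visible in the two hunks above. The observations and actions fields are assumptions carried over from stable-baselines3's RolloutBufferSamples, which this class appears to extend.

```python
from typing import NamedTuple
import torch as th

class BetterRolloutBufferSamples(NamedTuple):
    observations: th.Tensor  # assumed, as in sb3's RolloutBufferSamples
    actions: th.Tensor       # assumed; matches self.actions[batch_inds] in the next hunk
    old_values: th.Tensor
    old_log_prob: th.Tensor
    mean: th.Tensor          # per-sample mean of the Gaussian policy
    cov_decomp: th.Tensor    # per-sample covariance decomposition (renamed from cov_Decomp)
    advantages: th.Tensor
    returns: th.Tensor
```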
@@ -227,8 +227,8 @@ class BetterRolloutBuffer(RolloutBuffer):
             self.actions[batch_inds],
             self.values[batch_inds].flatten(),
             self.log_probs[batch_inds].flatten(),
-            self.means[batch_inds].flatten(),
-            self.cov_decomps[batch_inds].flatten(),
+            np.squeeze(self.means[batch_inds], axis=1),
+            np.squeeze(self.cov_decomps[batch_inds], axis=1),
             self.advantages[batch_inds].flatten(),
             self.returns[batch_inds].flatten(),
         )
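Why np.squeeze rather than .flatten(): for a Gaussian policy the mean is a vector and the covariance decomposition a matrix per sample, so flattening collapses the structure downstream code needs. A minimal sketch, assuming the buffer stores arrays with a singleton n_envs axis at position 1 (hypothetical shapes for illustration):

```python
import numpy as np

n_steps, n_envs, act_dim = 4, 1, 2
means = np.zeros((n_steps, n_envs, act_dim))                 # per-step policy means
cov_decomps = np.zeros((n_steps, n_envs, act_dim, act_dim))  # per-step decomposition matrices
batch_inds = np.array([0, 2])

# .flatten() collapses everything into one axis: the "weird shapes" from the commit title.
print(means[batch_inds].flatten().shape)                     # (4,)
print(cov_decomps[batch_inds].flatten().shape)               # (8,)

# np.squeeze only drops the singleton env axis, keeping one vector/matrix per sample.
print(np.squeeze(means[batch_inds], axis=1).shape)           # (2, 2)
print(np.squeeze(cov_decomps[batch_inds], axis=1).shape)     # (2, 2, 2)
```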
@@ -39,28 +39,31 @@ class TRPL(BetterOnPolicyAlgorithm):
     :param n_epochs: Number of epoch when optimizing the surrogate loss
     :param gamma: Discount factor
     :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
-    :param clip_range: Clipping parameter, it can be a function of the current progress
-        remaining (from 1 to 0).
-    :param clip_range_vf: Clipping parameter for the value function,
-        it can be a function of the current progress remaining (from 1 to 0).
+    :param clip_range: Should not be used for normal TRPL usage. Is only here to bridge the gap to PPO.
+        Clipping parameter, it can be a function of the current progress remaining (from 1 to 0).
+        Setting it to None will result in no clipping (default)
+    :param clip_range_vf: Should not be used for normal TRPL usage. Is only here to bridge the gap to PPO.
+        Clipping parameter for the value function, it can be a function of the current progress remaining (from 1 to 0).
         This is a parameter specific to the OpenAI implementation. If None is passed (default),
         no clipping will be done on the value function.
         IMPORTANT: this clipping depends on the reward scaling.
-    :param normalize_advantage: Whether to normalize or not the advantage
+    :param normalize_advantage: Normally a good idea; but TRPL actually often works better without normalization of the advantage.
+        Whether to normalize or not the advantage. (Default: False)
     :param ent_coef: Entropy coefficient for the loss calculation
     :param vf_coef: Value function coefficient for the loss calculation
-    :param max_grad_norm: The maximum value for the gradient clipping
+    :param max_grad_norm: Should not be used for normal TRPL usage. Is only here to bridge the gap to PPO..
+        The maximum value for the gradient clipping. Setting it to None will result in no gradient clipping (default)
     :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
         instead of action noise exploration (default: False)
     :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
         Default: -1 (only sample at the beginning of the rollout)
-    :param use_pca: Wether to use Prior Conditioned Annealing
+    :param use_pca: Wether to use Prior Conditioned Annealing.
     :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
     :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation
-    :param target_kl: Limit the KL divergence between updates,
-        because the clipping is not enough to prevent large update
+    :param target_kl: Not part of reference implementation of TRPL, but we still ported it over from sb3's PPO.
+        Default to None: No limit.
+        Limit the KL divergence between updates, because the clipping is not enough to prevent large update
         see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
-        By default, there is no limit on the kl div.
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over
     :param tensorboard_log: the log location for tensorboard (if None, no logging)
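For reference, the target_kl mechanism the docstring says was ported over stops the surrogate-loss epochs once the approximate KL between old and new policy grows too large. A sketch of that check in the style of sb3's PPO (the 1.5 safety factor and the low-variance KL estimator follow that implementation; the function name is ours):

```python
from typing import Optional
import torch as th

def should_stop(log_prob: th.Tensor, old_log_prob: th.Tensor,
                target_kl: Optional[float]) -> bool:
    """Early-stopping check on the approximate KL divergence, as in sb3's PPO."""
    with th.no_grad():
        log_ratio = log_prob - old_log_prob
        # Low-variance estimator: E[(r - 1) - log r] with r = exp(log_ratio).
        approx_kl = th.mean((th.exp(log_ratio) - 1) - log_ratio).item()
    return target_kl is not None and approx_kl > 1.5 * target_kl
```

In sb3's PPO a True result breaks out of the n_epochs loop; with the default target_kl=None the check never triggers.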
@@ -89,7 +92,7 @@ class TRPL(BetterOnPolicyAlgorithm):
         gae_lambda: float = 0.95,
         clip_range: Union[float, Schedule, None] = None,
         clip_range_vf: Union[None, float, Schedule] = None,
-        normalize_advantage: bool = True,
+        normalize_advantage: bool = False,
         ent_coef: float = 0.0,
         vf_coef: float = 0.5,
         max_grad_norm: Union[float, None] = None,
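The normalize_advantage default flips from True to False here, matching the updated docstring ("TRPL actually often works better without normalization of the advantage"). For context, a sketch of the per-minibatch standardization this flag toggles in sb3-style training loops:

```python
import torch as th

def maybe_normalize(advantages: th.Tensor, normalize_advantage: bool) -> th.Tensor:
    # Standardize to zero mean / unit std; the epsilon guards against a zero
    # std when the minibatch has (nearly) constant advantages.
    if normalize_advantage and len(advantages) > 1:
        return (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return advantages
```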
@@ -97,7 +100,8 @@ class TRPL(BetterOnPolicyAlgorithm):
         sde_sample_freq: int = -1,
         use_pca: bool = False,
         pca_is: bool = False,
-        projection: Union[BaseProjectionLayer, str] = BaseProjectionLayer(),
+        projection_class: Union[BaseProjectionLayer, str] = BaseProjectionLayer,
+        projection_kwargs: Optional[Dict[str, Any]] = None,
         rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
         rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
         target_kl: Optional[float] = None,
@@ -109,6 +113,11 @@ class TRPL(BetterOnPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
     ):
+        self.projection_class = castProjection(projection_class)
+        self.projection_kwargs = projection_kwargs
+        self.projection = self.projection_class(**self.projection_kwargs)
+        policy_kwargs['policy_projection'] = self.projection
+
         super().__init__(
             policy,
             env,
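The constructor now takes a class (or a string naming one) plus a kwargs dict instead of a ready-made instance, and builds the projection itself before handing it to the policy. A minimal sketch of that pattern; castProjection's real lookup table is assumed, and the or-{} guard for a None projection_kwargs is our addition for the sketch (the hunk above passes self.projection_kwargs through unguarded):

```python
from typing import Any, Dict, Optional, Type, Union

class BaseProjectionLayer:
    """Stand-in so the sketch is self-contained; the real layer lives in the repo."""
    def __init__(self, mean_bound: float = 0.03, cov_bound: float = 1e-3):
        self.mean_bound = mean_bound
        self.cov_bound = cov_bound

# Hypothetical string-to-class table mirroring what castProjection presumably does.
_PROJECTIONS: Dict[str, Type[BaseProjectionLayer]] = {"base": BaseProjectionLayer}

def cast_projection(proj: Union[Type[BaseProjectionLayer], str]) -> Type[BaseProjectionLayer]:
    """Resolve a string name to a projection class; pass classes through unchanged."""
    return _PROJECTIONS[proj.lower()] if isinstance(proj, str) else proj

def build_projection(
    projection_class: Union[Type[BaseProjectionLayer], str] = BaseProjectionLayer,
    projection_kwargs: Optional[Dict[str, Any]] = None,
) -> BaseProjectionLayer:
    cls = cast_projection(projection_class)
    return cls(**(projection_kwargs or {}))

projection = build_projection("base", {"mean_bound": 0.05})
```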
@@ -176,7 +185,6 @@ class TRPL(BetterOnPolicyAlgorithm):
             clip_range_vf = None
         self.clip_range_vf = clip_range_vf
         self.normalize_advantage = normalize_advantage
-        self.projection = castProjection(projection)
         self.target_kl = target_kl

         if _init_setup_model: