Use sde_sample_freq (ssf) also for PCA noise
This commit is contained in:
parent
f0cd88365e
commit
3e27ad3766
@ -498,13 +498,13 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
|
||||
if action_noise is not None and env.num_envs > 1 and not isinstance(action_noise, VectorizedActionNoise):
|
||||
action_noise = VectorizedActionNoise(action_noise, env.num_envs)
|
||||
|
||||
if self.use_sde:
|
||||
if self.use_sde or self.use_pca:
|
||||
self.actor.reset_noise(env.num_envs)
|
||||
|
||||
callback.on_rollout_start()
|
||||
continue_training = True
|
||||
while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
|
||||
if self.use_sde and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
|
||||
if (self.use_sde or self.use_pca) and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
|
||||
# Sample a new noise matrix
|
||||
self.actor.reset_noise(env.num_envs)
|
||||
|
||||
|
@ -548,8 +548,10 @@ class ActorCriticPolicy(BasePolicy):
|
||||
|
||||
:param n_envs:
|
||||
"""
|
||||
assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE"
|
||||
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
|
||||
if isinstance(self.action_dist, StateDependentNoiseDistribution):
|
||||
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
|
||||
else:
|
||||
self.action_dist.base_noise.reset()
|
||||
|
||||
def _build_mlp_extractor(self) -> None:
|
||||
"""
|
||||
@ -887,8 +889,10 @@ class Actor(BasePolicy):
|
||||
:param batch_size:
|
||||
"""
|
||||
msg = "reset_noise() is only available when using gSDE"
|
||||
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
|
||||
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
|
||||
if isinstance(self.action_dist, StateDependentNoiseDistribution):
|
||||
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
|
||||
else:
|
||||
self.action_dist.base_noise.reset()
|
||||
|
||||
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
|
||||
"""
|
||||
@ -1083,7 +1087,10 @@ class SACPolicy(BasePolicy):
|
||||
|
||||
:param batch_size:
|
||||
"""
|
||||
self.actor.reset_noise(batch_size=batch_size)
|
||||
if isinstance(self.action_space, StateDependentNoiseDistribution):
|
||||
self.actor.reset_noise(batch_size=batch_size)
|
||||
else:
|
||||
self.actor.reset_noise(batch_size=batch_size)
|
||||
|
||||
def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
|
||||
actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
|
||||
|
@ -208,7 +208,7 @@ class PPO(BetterOnPolicyAlgorithm):
|
||||
actions = rollout_data.actions.long().flatten()
|
||||
|
||||
# Re-sample the noise matrix because the log_std has changed
|
||||
if self.use_sde:
|
||||
if self.use_sde or self.use_pca:
|
||||
self.policy.reset_noise(self.batch_size)
|
||||
|
||||
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
|
||||
|
Loading…
Reference in New Issue
Block a user