Use sde_sample_freq (ssf) also for PCA noise

This commit is contained in:
Dominik Moritz Roth 2023-09-07 21:07:46 +02:00
parent f0cd88365e
commit 3e27ad3766
3 changed files with 15 additions and 8 deletions

View File

@@ -498,13 +498,13 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
if action_noise is not None and env.num_envs > 1 and not isinstance(action_noise, VectorizedActionNoise):
action_noise = VectorizedActionNoise(action_noise, env.num_envs)
if self.use_sde:
if self.use_sde or self.use_pca:
self.actor.reset_noise(env.num_envs)
callback.on_rollout_start()
continue_training = True
while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
if self.use_sde and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
if (self.use_sde or self.use_pca) and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
# Sample a new noise matrix
self.actor.reset_noise(env.num_envs)

View File

@@ -548,8 +548,10 @@ class ActorCriticPolicy(BasePolicy):
:param n_envs:
"""
assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE"
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
if isinstance(self.action_dist, StateDependentNoiseDistribution):
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
else:
self.action_dist.base_noise.reset()
def _build_mlp_extractor(self) -> None:
"""
@@ -887,8 +889,10 @@ class Actor(BasePolicy):
:param batch_size:
"""
msg = "reset_noise() is only available when using gSDE"
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
if isinstance(self.action_dist, StateDependentNoiseDistribution):
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
else:
self.action_dist.base_noise.reset()
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
"""
@@ -1083,7 +1087,10 @@ class SACPolicy(BasePolicy):
:param batch_size:
"""
self.actor.reset_noise(batch_size=batch_size)
if isinstance(self.action_space, StateDependentNoiseDistribution):
self.actor.reset_noise(batch_size=batch_size)
else:
self.actor.reset_noise(batch_size=batch_size)
def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)

View File

@@ -208,7 +208,7 @@ class PPO(BetterOnPolicyAlgorithm):
actions = rollout_data.actions.long().flatten()
# Re-sample the noise matrix because the log_std has changed
if self.use_sde:
if self.use_sde or self.use_pca:
self.policy.reset_noise(self.batch_size)
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)