diff --git a/sbBrix/common/off_policy_algorithm.py b/sbBrix/common/off_policy_algorithm.py
index 816d033..0720e33 100644
--- a/sbBrix/common/off_policy_algorithm.py
+++ b/sbBrix/common/off_policy_algorithm.py
@@ -498,13 +498,13 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
         if action_noise is not None and env.num_envs > 1 and not isinstance(action_noise, VectorizedActionNoise):
             action_noise = VectorizedActionNoise(action_noise, env.num_envs)
 
-        if self.use_sde:
+        if self.use_sde or self.use_pca:
             self.actor.reset_noise(env.num_envs)
 
         callback.on_rollout_start()
         continue_training = True
 
         while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
-            if self.use_sde and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
+            if (self.use_sde or self.use_pca) and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
                 # Sample a new noise matrix
                 self.actor.reset_noise(env.num_envs)
diff --git a/sbBrix/common/policies.py b/sbBrix/common/policies.py
index 67fc963..771ae0f 100644
--- a/sbBrix/common/policies.py
+++ b/sbBrix/common/policies.py
@@ -548,8 +548,10 @@ class ActorCriticPolicy(BasePolicy):
 
         :param n_envs:
         """
-        assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE"
-        self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
+        if isinstance(self.action_dist, StateDependentNoiseDistribution):
+            self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
+        else:
+            self.action_dist.base_noise.reset()
 
     def _build_mlp_extractor(self) -> None:
         """
@@ -887,8 +889,10 @@ class Actor(BasePolicy):
         :param batch_size:
         """
         msg = "reset_noise() is only available when using gSDE"
-        assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
-        self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
+        if isinstance(self.action_dist, StateDependentNoiseDistribution):
+            self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
+        else:
+            self.action_dist.base_noise.reset()
 
     def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
         """
@@ -1083,7 +1087,10 @@ class SACPolicy(BasePolicy):
 
         :param batch_size:
         """
-        self.actor.reset_noise(batch_size=batch_size)
+        if isinstance(self.action_space, StateDependentNoiseDistribution):
+            self.actor.reset_noise(batch_size=batch_size)
+        else:
+            self.actor.reset_noise(batch_size=batch_size)
 
     def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
         actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
diff --git a/sbBrix/ppo/ppo.py b/sbBrix/ppo/ppo.py
index 0181975..b3ff5f1 100644
--- a/sbBrix/ppo/ppo.py
+++ b/sbBrix/ppo/ppo.py
@@ -208,7 +208,7 @@ class PPO(BetterOnPolicyAlgorithm):
                 actions = rollout_data.actions.long().flatten()
 
             # Re-sample the noise matrix because the log_std has changed
-            if self.use_sde:
+            if self.use_sde or self.use_pca:
                 self.policy.reset_noise(self.batch_size)
 
             values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
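
For context, below is a minimal standalone sketch (not part of the patch) of the dispatch pattern the diff introduces: reset_noise() no longer asserts gSDE but branches on the distribution type, resampling the gSDE exploration matrix in one case and resetting the distribution's base noise otherwise, while the rollout/train loops now trigger the reset when either use_sde or use_pca is set. The class names GsdeDistribution, PcaDistribution, and PcaNoise are hypothetical stand-ins invented for illustration; the only interface assumed from the patch is that the non-gSDE distribution exposes base_noise.reset().

# sketch.py -- illustrative only; names other than base_noise.reset() are assumptions
class GsdeDistribution:
    """Stand-in for StateDependentNoiseDistribution: exploration comes from sampled weights."""

    def sample_weights(self, log_std, batch_size=1):
        print(f"gSDE: resampling exploration matrix for {batch_size} envs")


class PcaNoise:
    """Stand-in for the resettable base noise carried by the PCA distribution."""

    def reset(self):
        print("PCA: resetting base noise state")


class PcaDistribution:
    """Stand-in for the PCA action distribution; it holds a base_noise object."""

    def __init__(self):
        self.base_noise = PcaNoise()


class Actor:
    def __init__(self, action_dist, log_std=None):
        self.action_dist = action_dist
        self.log_std = log_std

    def reset_noise(self, batch_size=1):
        # Same dispatch as the patched policies.py: gSDE resamples its weight
        # matrix, any other distribution falls back to resetting its base noise.
        if isinstance(self.action_dist, GsdeDistribution):
            self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
        else:
            self.action_dist.base_noise.reset()


if __name__ == "__main__":
    use_sde, use_pca = False, True
    actor = Actor(PcaDistribution() if use_pca else GsdeDistribution())
    # Mirrors the rollout-loop guard: the reset now also fires when use_pca is set.
    if use_sde or use_pca:
        actor.reset_noise(batch_size=4)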