diff --git a/sbBrix/common/policies.py b/sbBrix/common/policies.py index 365f721..61d5e38 100644 --- a/sbBrix/common/policies.py +++ b/sbBrix/common/policies.py @@ -740,6 +740,9 @@ class ActorCriticPolicy(BasePolicy): return self.value_net(latent_vf) +LOG_STD_MIN, LOG_STD_MAX = 0.1, 1000 + + class Actor(BasePolicy): """ Actor network (policy) for SAC. @@ -812,7 +815,7 @@ class Actor(BasePolicy): action_dim, **dist_kwargs ) self.mu, self.log_std = self.action_dist.proba_distribution_net( - latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init, **dist_kwargs + latent_dim=last_layer_dim, return_log_std=True, **dist_kwargs ) # Avoid numerical issues by limiting the mean of the Gaussian # to be in [-clip_mean, clip_mean] @@ -865,7 +868,7 @@ class Actor(BasePolicy): if isinstance(self.action_dist, StateDependentNoiseDistribution): return self.action_dist.get_std(self.log_std) else: - return th.exp(self.log_std) + return th.exp(self._remember_log_std) def reset_noise(self, batch_size: int = 1) -> None: """ @@ -893,6 +896,7 @@ class Actor(BasePolicy): return mean_actions, self.log_std, dict(latent_sde=latent_pi) # Unstructured exploration (Original implementation) log_std = self.log_std(latent_pi) + self._remember_log_std = log_std # Original Implementation to cap the standard deviation log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) return mean_actions, log_std, {}