diff --git a/metastable_baselines/ppo/policies.py b/metastable_baselines/ppo/policies.py index 4de69f9..0885953 100644 --- a/metastable_baselines/ppo/policies.py +++ b/metastable_baselines/ppo/policies.py @@ -88,7 +88,7 @@ class ActorCriticPolicy(BasePolicy): activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, use_sde: bool = False, - log_std_init: float = 0.0, + std_init: float = 1.0, full_std: bool = True, sde_net_arch: Optional[List[int]] = None, use_expln: bool = False, @@ -135,7 +135,7 @@ class ActorCriticPolicy(BasePolicy): self.features_dim = self.features_extractor.features_dim self.normalize_images = normalize_images - self.log_std_init = log_std_init + self.log_std_init = math.log(std_init) # Keyword arguments for gSDE distribution if dist_kwargs == None: dist_kwargs = {} diff --git a/metastable_baselines/sac/policies.py b/metastable_baselines/sac/policies.py index 435a5c8..a01d64c 100644 --- a/metastable_baselines/sac/policies.py +++ b/metastable_baselines/sac/policies.py @@ -265,7 +265,7 @@ class SACPolicy(BasePolicy): net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, use_sde: bool = False, - log_std_init: float = -3, + std_init: float = 0.05, sde_net_arch: Optional[List[int]] = None, use_expln: bool = False, clip_mean: float = 2.0, @@ -314,7 +314,7 @@ class SACPolicy(BasePolicy): sde_kwargs = { "use_sde": use_sde, - "log_std_init": log_std_init, + "log_std_init": math.log(std_init), "use_expln": use_expln, "clip_mean": clip_mean, }