diff --git a/metastable_baselines2/common/policies.py b/metastable_baselines2/common/policies.py index 9fba2da..937bcb0 100644 --- a/metastable_baselines2/common/policies.py +++ b/metastable_baselines2/common/policies.py @@ -11,6 +11,7 @@ import numpy as np import torch as th from gymnasium import spaces from torch import nn +import math from stable_baselines3.common.distributions import ( BernoulliDistribution, @@ -514,6 +515,11 @@ class ActorCriticPolicy(BasePolicy): "learn_features": False, } dist_kwargs.update(add_dist_kwargs) + if use_pca: + add_dist_kwargs = { + "init_std": math.exp(self.log_std_init) + } + dist_kwargs.update(add_dist_kwargs) self.use_sde = use_sde self.use_pca = use_pca