Unify how init_std is passed into policy

This commit is contained in:
Dominik Moritz Roth 2024-01-29 18:11:33 +01:00
parent 5fa351db22
commit 6e79fce9ae

View File

@@ -11,6 +11,7 @@ import numpy as np
import torch as th
from gymnasium import spaces
from torch import nn
import math
from stable_baselines3.common.distributions import (
BernoulliDistribution,
@@ -514,6 +515,11 @@ class ActorCriticPolicy(BasePolicy):
"learn_features": False,
}
dist_kwargs.update(add_dist_kwargs)
if use_pca:
add_dist_kwargs = {
"init_std": math.exp(self.log_std_init)
}
dist_kwargs.update(add_dist_kwargs)
self.use_sde = use_sde
self.use_pca = use_pca