Unify how init_std is passed into policy

This commit is contained in:
Dominik Moritz Roth 2024-01-29 18:11:33 +01:00
parent 5fa351db22
commit 6e79fce9ae

View File

@ -11,6 +11,7 @@ import numpy as np
import torch as th import torch as th
from gymnasium import spaces from gymnasium import spaces
from torch import nn from torch import nn
import math
from stable_baselines3.common.distributions import ( from stable_baselines3.common.distributions import (
BernoulliDistribution, BernoulliDistribution,
@ -514,6 +515,11 @@ class ActorCriticPolicy(BasePolicy):
"learn_features": False, "learn_features": False,
} }
dist_kwargs.update(add_dist_kwargs) dist_kwargs.update(add_dist_kwargs)
if use_pca:
add_dist_kwargs = {
"init_std": math.exp(self.log_std_init)
}
dist_kwargs.update(add_dist_kwargs)
self.use_sde = use_sde self.use_sde = use_sde
self.use_pca = use_pca self.use_pca = use_pca