Unify how init_std is passed into policy
This commit is contained in:
parent
5fa351db22
commit
6e79fce9ae
@ -11,6 +11,7 @@ import numpy as np
|
||||
import torch as th
|
||||
from gymnasium import spaces
|
||||
from torch import nn
|
||||
import math
|
||||
|
||||
from stable_baselines3.common.distributions import (
|
||||
BernoulliDistribution,
|
||||
@ -514,6 +515,11 @@ class ActorCriticPolicy(BasePolicy):
|
||||
"learn_features": False,
|
||||
}
|
||||
dist_kwargs.update(add_dist_kwargs)
|
||||
if use_pca:
|
||||
add_dist_kwargs = {
|
||||
"init_std": math.exp(self.log_std_init)
|
||||
}
|
||||
dist_kwargs.update(add_dist_kwargs)
|
||||
|
||||
self.use_sde = use_sde
|
||||
self.use_pca = use_pca
|
||||
|
Loading…
Reference in New Issue
Block a user