Unify how init_std is passed into policy
This commit is contained in:
parent
5fa351db22
commit
6e79fce9ae
@ -11,6 +11,7 @@ import numpy as np
|
|||||||
import torch as th
|
import torch as th
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
import math
|
||||||
|
|
||||||
from stable_baselines3.common.distributions import (
|
from stable_baselines3.common.distributions import (
|
||||||
BernoulliDistribution,
|
BernoulliDistribution,
|
||||||
@ -514,6 +515,11 @@ class ActorCriticPolicy(BasePolicy):
|
|||||||
"learn_features": False,
|
"learn_features": False,
|
||||||
}
|
}
|
||||||
dist_kwargs.update(add_dist_kwargs)
|
dist_kwargs.update(add_dist_kwargs)
|
||||||
|
if use_pca:
|
||||||
|
add_dist_kwargs = {
|
||||||
|
"init_std": math.exp(self.log_std_init)
|
||||||
|
}
|
||||||
|
dist_kwargs.update(add_dist_kwargs)
|
||||||
|
|
||||||
self.use_sde = use_sde
|
self.use_sde = use_sde
|
||||||
self.use_pca = use_pca
|
self.use_pca = use_pca
|
||||||
|
Loading…
Reference in New Issue
Block a user