From 6e79fce9aeed969cf5f95f528a8f7bb1abf69d68 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 29 Jan 2024 18:11:33 +0100 Subject: [PATCH] Unify how init_std is passed into policy --- metastable_baselines2/common/policies.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/metastable_baselines2/common/policies.py b/metastable_baselines2/common/policies.py index 9fba2da..937bcb0 100644 --- a/metastable_baselines2/common/policies.py +++ b/metastable_baselines2/common/policies.py @@ -11,6 +11,7 @@ import numpy as np import torch as th from gymnasium import spaces from torch import nn +import math from stable_baselines3.common.distributions import ( BernoulliDistribution, @@ -514,6 +515,11 @@ class ActorCriticPolicy(BasePolicy): "learn_features": False, } dist_kwargs.update(add_dist_kwargs) + if use_pca: + add_dist_kwargs = { + "init_std": math.exp(self.log_std_init) + } + dist_kwargs.update(add_dist_kwargs) self.use_sde = use_sde self.use_pca = use_pca