From 3f3912eed1bf64a899193d940372630864b7a0fe Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 22 Aug 2023 00:30:17 +0200 Subject: [PATCH] Bug Fix: Logging of std for PCA based SAC --- sbBrix/common/policies.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sbBrix/common/policies.py b/sbBrix/common/policies.py index 40d36eb..365f721 100644 --- a/sbBrix/common/policies.py +++ b/sbBrix/common/policies.py @@ -678,6 +678,7 @@ class ActorCriticPolicy(BasePolicy): return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi) elif isinstance(self.action_dist, PCA_Distribution): std_actions = self.std_net(latent_pi) + self.log_std = th.log(std_actions) return self.action_dist.proba_distribution(mean_actions, std_actions) else: raise ValueError("Invalid action distribution") @@ -861,9 +862,10 @@ class Actor(BasePolicy): :return: """ - msg = "get_std() is only available when using gSDE" - assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg - return self.action_dist.get_std(self.log_std) + if isinstance(self.action_dist, StateDependentNoiseDistribution): + return self.action_dist.get_std(self.log_std) + else: + return th.exp(self.log_std) def reset_noise(self, batch_size: int = 1) -> None: """