Bug Fix: Logging of std for PCA-based SAC

This commit is contained in:
Dominik Moritz Roth 2023-08-22 00:30:17 +02:00
parent e39f1573cf
commit 3f3912eed1

View File

@@ -678,6 +678,7 @@ class ActorCriticPolicy(BasePolicy):
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
elif isinstance(self.action_dist, PCA_Distribution):
std_actions = self.std_net(latent_pi)
self.log_std = th.log(std_actions)
return self.action_dist.proba_distribution(mean_actions, std_actions)
else:
raise ValueError("Invalid action distribution")
@@ -861,9 +862,10 @@ class Actor(BasePolicy):
:return:
"""
msg = "get_std() is only available when using gSDE"
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
return self.action_dist.get_std(self.log_std)
if isinstance(self.action_dist, StateDependentNoiseDistribution):
return self.action_dist.get_std(self.log_std)
else:
return th.exp(self.log_std)
def reset_noise(self, batch_size: int = 1) -> None:
"""