Bug Fix: Logging of std for PCA based SAC
This commit is contained in:
parent
e39f1573cf
commit
3f3912eed1
@ -678,6 +678,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||||||
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
|
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
|
||||||
elif isinstance(self.action_dist, PCA_Distribution):
|
elif isinstance(self.action_dist, PCA_Distribution):
|
||||||
std_actions = self.std_net(latent_pi)
|
std_actions = self.std_net(latent_pi)
|
||||||
|
self.log_std = th.log(std_actions)
|
||||||
return self.action_dist.proba_distribution(mean_actions, std_actions)
|
return self.action_dist.proba_distribution(mean_actions, std_actions)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid action distribution")
|
raise ValueError("Invalid action distribution")
|
||||||
@ -861,9 +862,10 @@ class Actor(BasePolicy):
|
|||||||
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
msg = "get_std() is only available when using gSDE"
|
if isinstance(self.action_dist, StateDependentNoiseDistribution):
|
||||||
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
|
return self.action_dist.get_std(self.log_std)
|
||||||
return self.action_dist.get_std(self.log_std)
|
else:
|
||||||
|
return th.exp(self.log_std)
|
||||||
|
|
||||||
def reset_noise(self, batch_size: int = 1) -> None:
|
def reset_noise(self, batch_size: int = 1) -> None:
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user