Allow StdNet to predict log_std instead of std

This commit is contained in:
Dominik Moritz Roth 2023-08-22 00:54:26 +02:00
parent fd1d96eea3
commit f93d9adc66

View File

@ -114,9 +114,9 @@ class PCA_Distribution(SB3_Distribution):
self._build_conditioner()
# *Optimizes it anyways*
def proba_distribution_net(self, latent_dim: int):
def proba_distribution_net(self, latent_dim: int, return_log_std: bool = False):
mu_net = nn.Linear(latent_dim, self.action_dim)
std_net = StdNet(latent_dim, self.action_dim, self.init_std, self.par_strength, self.epsilon)
std_net = StdNet(latent_dim, self.action_dim, self.init_std, self.par_strength, self.epsilon, return_log_std)
return mu_net, std_net
@ -273,7 +273,7 @@ class PCA_Distribution(SB3_Distribution):
class StdNet(nn.Module):
def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: bool, epsilon: float):
def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: bool, epsilon: float, return_log_std):
super().__init__()
self.action_dim = action_dim
self.latent_dim = latent_dim
@ -282,6 +282,9 @@ class StdNet(nn.Module):
self.enforce_positive_type = EnforcePositiveType.SOFTPLUS
self.epsilon = epsilon
self.return_log_std = return_log_std
if return_log_std:
self.enforce_positive_type = EnforcePositiveType.NONE
if self.par_strength == Par_Strength.SCALAR:
self.param = nn.Parameter(