Allow StdNet to predict log_std instead of std

This commit is contained in:
Dominik Moritz Roth 2023-08-22 00:54:26 +02:00
parent fd1d96eea3
commit f93d9adc66

View File

@ -114,9 +114,9 @@ class PCA_Distribution(SB3_Distribution):
self._build_conditioner() self._build_conditioner()
# *Optimizes it anyways* # *Optimizes it anyways*
def proba_distribution_net(self, latent_dim: int): def proba_distribution_net(self, latent_dim: int, return_log_std: bool = False):
mu_net = nn.Linear(latent_dim, self.action_dim) mu_net = nn.Linear(latent_dim, self.action_dim)
std_net = StdNet(latent_dim, self.action_dim, self.init_std, self.par_strength, self.epsilon) std_net = StdNet(latent_dim, self.action_dim, self.init_std, self.par_strength, self.epsilon, return_log_std)
return mu_net, std_net return mu_net, std_net
@ -273,7 +273,7 @@ class PCA_Distribution(SB3_Distribution):
class StdNet(nn.Module): class StdNet(nn.Module):
def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: bool, epsilon: float): def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: bool, epsilon: float, return_log_std):
super().__init__() super().__init__()
self.action_dim = action_dim self.action_dim = action_dim
self.latent_dim = latent_dim self.latent_dim = latent_dim
@ -282,6 +282,9 @@ class StdNet(nn.Module):
self.enforce_positive_type = EnforcePositiveType.SOFTPLUS self.enforce_positive_type = EnforcePositiveType.SOFTPLUS
self.epsilon = epsilon self.epsilon = epsilon
self.return_log_std = return_log_std
if return_log_std:
self.enforce_positive_type = EnforcePositiveType.NONE
if self.par_strength == Par_Strength.SCALAR: if self.par_strength == Par_Strength.SCALAR:
self.param = nn.Parameter( self.param = nn.Parameter(