Allow StdNet to predict log_std instead of std
This commit is contained in:
parent
fd1d96eea3
commit
f93d9adc66
@ -114,9 +114,9 @@ class PCA_Distribution(SB3_Distribution):
|
||||
self._build_conditioner()
|
||||
# *Optimizes it anyways*
|
||||
|
||||
def proba_distribution_net(self, latent_dim: int):
|
||||
def proba_distribution_net(self, latent_dim: int, return_log_std: bool = False):
|
||||
mu_net = nn.Linear(latent_dim, self.action_dim)
|
||||
std_net = StdNet(latent_dim, self.action_dim, self.init_std, self.par_strength, self.epsilon)
|
||||
std_net = StdNet(latent_dim, self.action_dim, self.init_std, self.par_strength, self.epsilon, return_log_std)
|
||||
|
||||
return mu_net, std_net
|
||||
|
||||
@ -273,7 +273,7 @@ class PCA_Distribution(SB3_Distribution):
|
||||
|
||||
|
||||
class StdNet(nn.Module):
|
||||
def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: bool, epsilon: float):
|
||||
def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: bool, epsilon: float, return_log_std):
|
||||
super().__init__()
|
||||
self.action_dim = action_dim
|
||||
self.latent_dim = latent_dim
|
||||
@ -282,6 +282,9 @@ class StdNet(nn.Module):
|
||||
self.enforce_positive_type = EnforcePositiveType.SOFTPLUS
|
||||
|
||||
self.epsilon = epsilon
|
||||
self.return_log_std = return_log_std
|
||||
if return_log_std:
|
||||
self.enforce_positive_type = EnforcePositiveType.NONE
|
||||
|
||||
if self.par_strength == Par_Strength.SCALAR:
|
||||
self.param = nn.Parameter(
|
||||
|
Loading…
Reference in New Issue
Block a user