Allow StdNet to predict log_std instead of std
This commit is contained in:
parent
fd1d96eea3
commit
f93d9adc66
@ -114,9 +114,9 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
self._build_conditioner()
|
self._build_conditioner()
|
||||||
# *Optimizes it anyways*
|
# *Optimizes it anyways*
|
||||||
|
|
||||||
def proba_distribution_net(self, latent_dim: int):
|
def proba_distribution_net(self, latent_dim: int, return_log_std: bool = False):
|
||||||
mu_net = nn.Linear(latent_dim, self.action_dim)
|
mu_net = nn.Linear(latent_dim, self.action_dim)
|
||||||
std_net = StdNet(latent_dim, self.action_dim, self.init_std, self.par_strength, self.epsilon)
|
std_net = StdNet(latent_dim, self.action_dim, self.init_std, self.par_strength, self.epsilon, return_log_std)
|
||||||
|
|
||||||
return mu_net, std_net
|
return mu_net, std_net
|
||||||
|
|
||||||
@ -273,7 +273,7 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
|
|
||||||
|
|
||||||
class StdNet(nn.Module):
|
class StdNet(nn.Module):
|
||||||
def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: bool, epsilon: float):
|
def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: bool, epsilon: float, return_log_std):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.action_dim = action_dim
|
self.action_dim = action_dim
|
||||||
self.latent_dim = latent_dim
|
self.latent_dim = latent_dim
|
||||||
@ -282,6 +282,9 @@ class StdNet(nn.Module):
|
|||||||
self.enforce_positive_type = EnforcePositiveType.SOFTPLUS
|
self.enforce_positive_type = EnforcePositiveType.SOFTPLUS
|
||||||
|
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
|
self.return_log_std = return_log_std
|
||||||
|
if return_log_std:
|
||||||
|
self.enforce_positive_type = EnforcePositiveType.NONE
|
||||||
|
|
||||||
if self.par_strength == Par_Strength.SCALAR:
|
if self.par_strength == Par_Strength.SCALAR:
|
||||||
self.param = nn.Parameter(
|
self.param = nn.Parameter(
|
||||||
|
Loading…
Reference in New Issue
Block a user