Support for seting act_fn as string
This commit is contained in:
parent
3c81a15630
commit
f0cd88365e
@ -451,6 +451,11 @@ class ActorCriticPolicy(BasePolicy):
|
||||
if optimizer_class == th.optim.Adam:
|
||||
optimizer_kwargs["eps"] = 1e-5
|
||||
|
||||
if activation_fn == 'ReLU':
|
||||
activation_fn = nn.ReLU
|
||||
elif activation_fn == 'tanh':
|
||||
activation_fn = nn.Tanh
|
||||
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
@ -791,7 +796,11 @@ class Actor(BasePolicy):
|
||||
squash_output=True,
|
||||
)
|
||||
|
||||
# Save arguments to re-create object at loading
|
||||
if activation_fn == 'ReLU':
|
||||
activation_fn = nn.ReLU
|
||||
elif activation_fn == 'tanh':
|
||||
activation_fn = nn.Tanh
|
||||
|
||||
self.use_sde = use_sde
|
||||
self.use_pca = use_pca
|
||||
self.sde_features_extractor = None
|
||||
@ -977,6 +986,11 @@ class SACPolicy(BasePolicy):
|
||||
normalize_images=normalize_images,
|
||||
)
|
||||
|
||||
if activation_fn == 'ReLU':
|
||||
activation_fn = nn.ReLU
|
||||
elif activation_fn == 'tanh':
|
||||
activation_fn = nn.Tanh
|
||||
|
||||
if net_arch is None:
|
||||
net_arch = [256, 256]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user