diff --git a/sbBrix/common/policies.py b/sbBrix/common/policies.py index f8f3f4e..67fc963 100644 --- a/sbBrix/common/policies.py +++ b/sbBrix/common/policies.py @@ -451,6 +451,11 @@ class ActorCriticPolicy(BasePolicy): if optimizer_class == th.optim.Adam: optimizer_kwargs["eps"] = 1e-5 + if activation_fn == 'ReLU': + activation_fn = nn.ReLU + elif activation_fn == 'tanh': + activation_fn = nn.Tanh + super().__init__( observation_space, action_space, @@ -791,7 +796,11 @@ class Actor(BasePolicy): squash_output=True, ) - # Save arguments to re-create object at loading + if activation_fn == 'ReLU': + activation_fn = nn.ReLU + elif activation_fn == 'tanh': + activation_fn = nn.Tanh + self.use_sde = use_sde self.use_pca = use_pca self.sde_features_extractor = None @@ -977,6 +986,11 @@ class SACPolicy(BasePolicy): normalize_images=normalize_images, ) + if activation_fn == 'ReLU': + activation_fn = nn.ReLU + elif activation_fn == 'tanh': + activation_fn = nn.Tanh + if net_arch is None: net_arch = [256, 256]