From e6071a546b255e85c07ad74ce1f8aa24dae76684 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 22 Aug 2023 00:20:42 +0200 Subject: [PATCH] Fixes for SACPolicy --- sbBrix/sac/sac.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sbBrix/sac/sac.py b/sbBrix/sac/sac.py index 0979c1d..29ffdfb 100644 --- a/sbBrix/sac/sac.py +++ b/sbBrix/sac/sac.py @@ -12,7 +12,7 @@ from ..common.off_policy_algorithm import BetterOffPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_parameters_by_name, polyak_update -from ..common.policies import MlpPolicy, SACPolicy +from ..common.policies import SACPolicy SelfSAC = TypeVar("SelfSAC", bound="SAC") @@ -79,9 +79,8 @@ class SAC(BetterOffPolicyAlgorithm): """ policy_aliases: Dict[str, Type[BasePolicy]] = { - "MlpPolicy": MlpPolicy, - "CnnPolicy": CnnPolicy, - "MultiInputPolicy": MultiInputPolicy, + "MlpPolicy": SACPolicy, + "SACPolicy": SACPolicy, } def __init__(