Fixes for SACPolicy
This commit is contained in:
parent
28a518fe9d
commit
e6071a546b
@ -12,7 +12,7 @@ from ..common.off_policy_algorithm import BetterOffPolicyAlgorithm
|
|||||||
from stable_baselines3.common.policies import BasePolicy
|
from stable_baselines3.common.policies import BasePolicy
|
||||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||||
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
||||||
from ..common.policies import MlpPolicy, SACPolicy
|
from ..common.policies import SACPolicy
|
||||||
|
|
||||||
SelfSAC = TypeVar("SelfSAC", bound="SAC")
|
SelfSAC = TypeVar("SelfSAC", bound="SAC")
|
||||||
|
|
||||||
@ -79,9 +79,8 @@ class SAC(BetterOffPolicyAlgorithm):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
policy_aliases: Dict[str, Type[BasePolicy]] = {
|
policy_aliases: Dict[str, Type[BasePolicy]] = {
|
||||||
"MlpPolicy": MlpPolicy,
|
"MlpPolicy": SACPolicy,
|
||||||
"CnnPolicy": CnnPolicy,
|
"SACPolicy": SACPolicy,
|
||||||
"MultiInputPolicy": MultiInputPolicy,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
Loading…
Reference in New Issue
Block a user