Porting to sb3=2.1.0

Dominik Moritz Roth 2023-11-11 12:14:30 +01:00
parent 3e27ad3766
commit 1c1a909d27
15 changed files with 78 additions and 44 deletions

View File

@ -0,0 +1,14 @@
from sbBrix.ppo import PPO
from sbBrix.sac import SAC

try:
    import priorConditionedAnnealing as pca
except ModuleNotFoundError:
    class pca():
        def PCA_Distribution(*args, **kwargs):
            raise Exception('PCA is not installed; cannot initialize PCA_Distribution.')

__all__ = [
    "PPO",
    "SAC",
]
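A minimal usage sketch of the optional-PCA fallback above; the import path follows the new sac/__init__.py further down, and the env id is only an illustrative assumption, not part of the commit:

# Sketch only: the try/except above keeps the package importable when
# priorConditionedAnnealing is missing; training without PCA still works, and only
# a PCA-backed distribution would hit the raising stub.
from metastable_baselines2.sac import SAC  # path assumed from the new sac/__init__.py below

model = SAC("MlpPolicy", "Pendulum-v1", use_pca=False)  # plain Gaussian policy, no PCA required
model.learn(total_timesteps=1_000)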

View File

@ -1,5 +1,5 @@
from stable_baselines3.common.distributions import *
from priorConditionedAnnealing import PCA_Distribution
from metastable_baselines2.pca import PCA_Distribution
def _patched_make_proba_distribution(

View File

@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import numpy as np
import torch as th
from gym import spaces
from gymnasium import spaces
from torch import nn
from stable_baselines3.common.distributions import (
@ -34,11 +34,11 @@ from stable_baselines3.common.torch_layers import (
from stable_baselines3.common.policies import ContinuousCritic
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
from .distributions import make_proba_distribution
from priorConditionedAnnealing import PCA_Distribution
from metastable_baselines2.pca import PCA_Distribution
SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")
@ -773,10 +773,12 @@ class Actor(BasePolicy):
dividing by 255.0 (True by default)
"""
action_space: spaces.Box
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
action_space: spaces.Box,
net_arch: List[int],
features_extractor: nn.Module,
features_dim: int,
@ -894,7 +896,7 @@ class Actor(BasePolicy):
else:
self.action_dist.base_noise.reset()
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
"""
Get the parameters for the action distribution.
@ -915,17 +917,17 @@ class Actor(BasePolicy):
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
return mean_actions, log_std, {}
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
# Note: the action is squashed
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
# return action and associated log prob
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
return self(observation, deterministic)
@ -958,10 +960,14 @@ class SACPolicy(BasePolicy):
between the actor and the critic (this saves computation time)
"""
actor: Actor
critic: ContinuousCritic
critic_target: ContinuousCritic
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
action_space: spaces.Box,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
@ -1049,7 +1055,7 @@ class SACPolicy(BasePolicy):
# Create a separate features extractor for the critic
# this requires more memory and computation
self.critic = self.make_critic(features_extractor=None)
critic_parameters = self.critic.parameters()
critic_parameters = list(self.critic.parameters())
# Critic target should not share the features extractor with critic
self.critic_target = self.make_critic(features_extractor=None)
@ -1100,10 +1106,10 @@ class SACPolicy(BasePolicy):
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
return ContinuousCritic(**critic_kwargs).to(self.device)
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
return self.actor(observation, deterministic)
def set_training_mode(self, mode: bool) -> None:
@ -1117,3 +1123,5 @@ class SACPolicy(BasePolicy):
self.actor.set_training_mode(mode)
self.critic.set_training_mode(mode)
self.training = mode
SACMlpPolicy = SACPolicy
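The SACMlpPolicy alias added above is what sac.py registers under the "MlpPolicy" key further down; a one-line sketch of that relationship, with the import path assumed from the package rename:

# Sketch: "MlpPolicy" in SAC's policy_aliases resolves to SACMlpPolicy, which is
# simply another name for SACPolicy (see the alias above and policy_aliases in sac.py).
from metastable_baselines2.common.policies import SACMlpPolicy, SACPolicy

assert SACMlpPolicy is SACPolicy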

View File

@ -0,0 +1,5 @@
from metastable_baselines2.ppo.ppo import PPO
from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "PPO"]

View File

@ -1,14 +1,13 @@
import warnings
from typing import Any, Dict, Optional, Type, TypeVar, Union
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
import numpy as np
import torch as th
from gym import spaces
from gymnasium import spaces
from torch.nn import functional as F
# from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.buffers import RolloutBuffer
from ..common.on_policy_algorithm import BetterOnPolicyAlgorithm
# from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from ..common.policies import ActorCriticPolicy, BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
@ -54,6 +53,9 @@ class PPO(BetterOnPolicyAlgorithm):
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param use_pca: Whether to use Prior Conditioned Annealing
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation
:param target_kl: Limit the KL divergence between updates,
because the clipping is not enough to prevent large update
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
@ -70,7 +72,7 @@ class PPO(BetterOnPolicyAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: Dict[str, Type[BasePolicy]] = {
policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
"MlpPolicy": ActorCriticPolicy
}
@ -93,6 +95,8 @@ class PPO(BetterOnPolicyAlgorithm):
use_sde: bool = False,
sde_sample_freq: int = -1,
use_pca: bool = False,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
target_kl: Optional[float] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
@ -115,6 +119,8 @@ class PPO(BetterOnPolicyAlgorithm):
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
use_pca=use_pca,
rollout_buffer_class=rollout_buffer_class,
rollout_buffer_kwargs=rollout_buffer_kwargs,
stats_window_size=stats_window_size,
tensorboard_log=tensorboard_log,
policy_kwargs=policy_kwargs,
@ -130,7 +136,7 @@ class PPO(BetterOnPolicyAlgorithm):
),
)
print('[i] Using sbBrix version of PPO')
print('[i] Using metastable version of PPO')
# Sanity check, otherwise it will lead to noisy gradient and NaN
# because of the advantage normalization
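For reference, a short constructor sketch covering the arguments this hunk series adds or forwards (use_pca from this repo, rollout_buffer_class/rollout_buffer_kwargs from upstream SB3 2.1.0); the env id and values are placeholders:

# Sketch of the new PPO signature pieces shown above; values are illustrative only.
from stable_baselines3.common.buffers import RolloutBuffer
from metastable_baselines2.ppo import PPO  # path taken from the new ppo/__init__.py above

model = PPO(
    "MlpPolicy",
    "Pendulum-v1",
    use_pca=False,                       # Prior Conditioned Annealing toggle added by this repo
    rollout_buffer_class=RolloutBuffer,  # optional; SB3 picks a default when None
    rollout_buffer_kwargs=None,
    target_kl=0.05,
)
model.learn(total_timesteps=2_048)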

View File

@ -0,0 +1 @@
from metastable_baselines2.sac.sac import SAC

View File

@ -1,18 +1,18 @@
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union
import numpy as np
import torch as th
from gym import spaces
from gymnasium import spaces
from torch.nn import functional as F
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
# from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from ..common.off_policy_algorithm import BetterOffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from ..common.policies import SACPolicy
from ..common.policies import SACMlpPolicy, SACPolicy, Actor #, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
SelfSAC = TypeVar("SelfSAC", bound="SAC")
@ -78,11 +78,16 @@ class SAC(BetterOffPolicyAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": SACPolicy,
policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
"MlpPolicy": SACMlpPolicy,
"SACPolicy": SACPolicy,
}
policy: SACPolicy
actor: Actor
critic: ContinuousCritic
critic_target: ContinuousCritic
def __init__(
self,
policy: Union[str, Type[SACPolicy]],
@ -139,11 +144,11 @@ class SAC(BetterOffPolicyAlgorithm):
use_sde_at_warmup=use_sde_at_warmup,
use_pca=use_pca,
optimize_memory_usage=optimize_memory_usage,
supported_action_spaces=(spaces.Box),
supported_action_spaces=(spaces.Box,),
support_multi_env=True,
)
print('[i] Using sbBrix version of SAC')
print('[i] Using metastable version of SAC')
self.target_entropy = target_entropy
self.log_ent_coef = None # type: Optional[th.Tensor]
@ -151,7 +156,7 @@ class SAC(BetterOffPolicyAlgorithm):
# Inverse of the reward scale
self.ent_coef = ent_coef
self.target_update_interval = target_update_interval
self.ent_coef_optimizer = None
self.ent_coef_optimizer: Optional[th.optim.Adam] = None
if _init_setup_model:
self._setup_model()
@ -165,7 +170,7 @@ class SAC(BetterOffPolicyAlgorithm):
# Target entropy is used when learning the entropy coefficient
if self.target_entropy == "auto":
# automatically set target entropy if needed
self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)
self.target_entropy = float(-np.prod(self.env.action_space.shape).astype(np.float32)) # type: ignore
else:
# Force conversion
# this will also throw an error for unexpected string
@ -212,7 +217,7 @@ class SAC(BetterOffPolicyAlgorithm):
for gradient_step in range(gradient_steps):
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr]
# We need to sample because `log_std` may have changed between two gradient steps
if self.use_sde:
@ -223,7 +228,7 @@ class SAC(BetterOffPolicyAlgorithm):
log_prob = log_prob.reshape(-1, 1)
ent_coef_loss = None
if self.ent_coef_optimizer is not None:
if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
# Important: detach the variable from the graph
# so we don't change it with other losses
# see https://github.com/rail-berkeley/softlearning/issues/60
@ -259,7 +264,8 @@ class SAC(BetterOffPolicyAlgorithm):
# Compute critic loss
critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
critic_losses.append(critic_loss.item())
assert isinstance(critic_loss, th.Tensor) # for type checker
critic_losses.append(critic_loss.item()) # type: ignore[union-attr]
# Optimize the critic
self.critic.optimizer.zero_grad()
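Two of the SAC changes above are easy to check in isolation: the one-element tuple passed as supported_action_spaces and the float() cast on the automatic target entropy. A small sketch, using an assumed 3-dimensional Box action space:

# Sketch: the trailing comma makes supported_action_spaces a real 1-tuple, matching
# the tuple-of-space-types annotation in SB3's base class; "auto" target entropy
# then becomes a plain Python float equal to -action_dim.
import numpy as np
from gymnasium import spaces

assert (spaces.Box) is spaces.Box        # parentheses alone do not build a tuple
assert isinstance((spaces.Box,), tuple)  # the ported form is an actual tuple

action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))  # assumed example space
target_entropy = float(-np.prod(action_space.shape).astype(np.float32))
assert target_entropy == -3.0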

View File

@ -0,0 +1,2 @@
from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from metastable_baselines2.trpl.trpl import TRPL

View File

@ -0,0 +1 @@
pass

View File

@ -1,7 +0,0 @@
from sbBrix.ppo import PPO
from sbBrix.sac import SAC

__all__ = [
    "PPO",
    "SAC",
]

View File

@ -1 +0,0 @@
from sbBrix.ppo.ppo import PPO

View File

@ -1 +0,0 @@
from sbBrix.sac.sac import SAC

View File

@ -1,12 +1,12 @@
from setuptools import setup, find_packages
setup(
name='sbBrix',
version='1.0.0',
name='metastable_baselines2',
version='2.1.0.0',
# url='https://github.com/mypackage.git',
# author='Author Name',
# author_email='author@gmail.com',
# description='Description of my package',
packages=['.'],
install_requires=['gym', 'stable_baselines3==1.8.0'],
install_requires=['gymnasium', 'stable_baselines3==2.1.0'],
)
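A quick environment sanity check matching the new pins in setup.py (gymnasium plus stable_baselines3==2.1.0); the __version__ attributes used here are standard in both packages:

# Sketch: confirm the runtime environment matches the dependencies pinned above.
import gymnasium
import stable_baselines3 as sb3

assert sb3.__version__.startswith("2.1."), f"expected sb3 2.1.x, got {sb3.__version__}"
print("gymnasium", gymnasium.__version__, "| stable-baselines3", sb3.__version__)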