diff --git a/metastable_baselines2/__init__.py b/metastable_baselines2/__init__.py new file mode 100644 index 0000000..19bfa76 --- /dev/null +++ b/metastable_baselines2/__init__.py @@ -0,0 +1,14 @@ +from sbBrix.ppo import PPO +from sbBrix.sac import SAC + +try: + import priorConditionedAnnealing as pca +except ModuleNotFoundError: + class pca(): + def PCA_Distribution(*args, **kwargs): + raise Exception('PCA is not installed; cannot initialize PCA_Distribution.') + +__all__ = [ + "PPO", + "SAC", +] diff --git a/sbBrix/common/distributions.py b/metastable_baselines2/common/distributions.py similarity index 97% rename from sbBrix/common/distributions.py rename to metastable_baselines2/common/distributions.py index 9e48ac4..4daa98e 100644 --- a/sbBrix/common/distributions.py +++ b/metastable_baselines2/common/distributions.py @@ -1,5 +1,5 @@ from stable_baselines3.common.distributions import * -from priorConditionedAnnealing import PCA_Distribution +from metastable_baselines2.pca import PCA_Distribution def _patched_make_proba_distribution( diff --git a/sbBrix/common/off_policy_algorithm.py b/metastable_baselines2/common/off_policy_algorithm.py similarity index 100% rename from sbBrix/common/off_policy_algorithm.py rename to metastable_baselines2/common/off_policy_algorithm.py diff --git a/sbBrix/common/on_policy_algorithm.py b/metastable_baselines2/common/on_policy_algorithm.py similarity index 100% rename from sbBrix/common/on_policy_algorithm.py rename to metastable_baselines2/common/on_policy_algorithm.py diff --git a/sbBrix/common/policies.py b/metastable_baselines2/common/policies.py similarity index 97% rename from sbBrix/common/policies.py rename to metastable_baselines2/common/policies.py index 771ae0f..0867f18 100644 --- a/sbBrix/common/policies.py +++ b/metastable_baselines2/common/policies.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from torch import nn from stable_baselines3.common.distributions import ( @@ -34,11 +34,11 @@ from stable_baselines3.common.torch_layers import ( from stable_baselines3.common.policies import ContinuousCritic -from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.type_aliases import PyTorchObs, Schedule from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor from .distributions import make_proba_distribution -from priorConditionedAnnealing import PCA_Distribution +from metastable_baselines2.pca import PCA_Distribution SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel") @@ -773,10 +773,12 @@ class Actor(BasePolicy): dividing by 255.0 (True by default) """ + action_space: spaces.Box + def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, net_arch: List[int], features_extractor: nn.Module, features_dim: int, @@ -894,7 +896,7 @@ class Actor(BasePolicy): else: self.action_dist.base_noise.reset() - def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]: + def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]: """ Get the parameters for the action distribution. @@ -915,17 +917,17 @@ class Actor(BasePolicy): log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) return mean_actions, log_std, {} - def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: + def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor: mean_actions, log_std, kwargs = self.get_action_dist_params(obs) # Note: the action is squashed return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs) - def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]: mean_actions, log_std, kwargs = self.get_action_dist_params(obs) # return action and associated log prob return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs) - def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor: return self(observation, deterministic) @@ -958,10 +960,14 @@ class SACPolicy(BasePolicy): between the actor and the critic (this saves computation time) """ + actor: Actor + critic: ContinuousCritic + critic_target: ContinuousCritic + def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -1049,7 +1055,7 @@ class SACPolicy(BasePolicy): # Create a separate features extractor for the critic # this requires more memory and computation self.critic = self.make_critic(features_extractor=None) - critic_parameters = self.critic.parameters() + critic_parameters = list(self.critic.parameters()) # Critic target should not share the features extractor with critic self.critic_target = self.make_critic(features_extractor=None) @@ -1100,10 +1106,10 @@ class SACPolicy(BasePolicy): critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor) return ContinuousCritic(**critic_kwargs).to(self.device) - def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: + def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor: return self._predict(obs, deterministic=deterministic) - def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor: return self.actor(observation, deterministic) def set_training_mode(self, mode: bool) -> None: @@ -1117,3 +1123,5 @@ class SACPolicy(BasePolicy): self.actor.set_training_mode(mode) self.critic.set_training_mode(mode) self.training = mode + +SACMlpPolicy = SACPolicy \ No newline at end of file diff --git a/metastable_baselines2/ppo/__init__.py b/metastable_baselines2/ppo/__init__.py new file mode 100644 index 0000000..7debaf3 --- /dev/null +++ b/metastable_baselines2/ppo/__init__.py @@ -0,0 +1,5 @@ +from metastable_baselines2.ppo.ppo import PPO + +from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy + +__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "PPO"] \ No newline at end of file diff --git a/sbBrix/ppo/ppo.py b/metastable_baselines2/ppo/ppo.py similarity index 94% rename from sbBrix/ppo/ppo.py rename to metastable_baselines2/ppo/ppo.py index b3ff5f1..732f945 100644 --- a/sbBrix/ppo/ppo.py +++ b/metastable_baselines2/ppo/ppo.py @@ -1,14 +1,13 @@ import warnings -from typing import Any, Dict, Optional, Type, TypeVar, Union +from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from torch.nn import functional as F -# from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +from stable_baselines3.common.buffers import RolloutBuffer from ..common.on_policy_algorithm import BetterOnPolicyAlgorithm -# from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy from ..common.policies import ActorCriticPolicy, BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance, get_schedule_fn @@ -54,6 +53,9 @@ class PPO(BetterOnPolicyAlgorithm): instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) + :param use_pca: Wether to use Prior Conditioned Annealing + :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected. + :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation :param target_kl: Limit the KL divergence between updates, because the clipping is not enough to prevent large update see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) @@ -70,7 +72,7 @@ class PPO(BetterOnPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ - policy_aliases: Dict[str, Type[BasePolicy]] = { + policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = { "MlpPolicy": ActorCriticPolicy } @@ -93,6 +95,8 @@ class PPO(BetterOnPolicyAlgorithm): use_sde: bool = False, sde_sample_freq: int = -1, use_pca: bool = False, + rollout_buffer_class: Optional[Type[RolloutBuffer]] = None, + rollout_buffer_kwargs: Optional[Dict[str, Any]] = None, target_kl: Optional[float] = None, stats_window_size: int = 100, tensorboard_log: Optional[str] = None, @@ -115,6 +119,8 @@ class PPO(BetterOnPolicyAlgorithm): use_sde=use_sde, sde_sample_freq=sde_sample_freq, use_pca=use_pca, + rollout_buffer_class=rollout_buffer_class, + rollout_buffer_kwargs=rollout_buffer_kwargs, stats_window_size=stats_window_size, tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs, @@ -130,7 +136,7 @@ class PPO(BetterOnPolicyAlgorithm): ), ) - print('[i] Using sbBrix version of PPO') + print('[i] Using metastable version of PPO') # Sanity check, otherwise it will lead to noisy gradient and NaN # because of the advantage normalization diff --git a/metastable_baselines2/sac/__init__.py b/metastable_baselines2/sac/__init__.py new file mode 100644 index 0000000..34eae34 --- /dev/null +++ b/metastable_baselines2/sac/__init__.py @@ -0,0 +1 @@ +from metastable_baselines2.sac.sac import SAC diff --git a/sbBrix/sac/sac.py b/metastable_baselines2/sac/sac.py similarity index 93% rename from sbBrix/sac/sac.py rename to metastable_baselines2/sac/sac.py index 29ffdfb..4db8c62 100644 --- a/sbBrix/sac/sac.py +++ b/metastable_baselines2/sac/sac.py @@ -1,18 +1,18 @@ -from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise # from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from ..common.off_policy_algorithm import BetterOffPolicyAlgorithm -from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.policies import BasePolicy, ContinuousCritic from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_parameters_by_name, polyak_update -from ..common.policies import SACPolicy +from ..common.policies import SACMlpPolicy, SACPolicy, Actor #, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy SelfSAC = TypeVar("SelfSAC", bound="SAC") @@ -78,11 +78,16 @@ class SAC(BetterOffPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ - policy_aliases: Dict[str, Type[BasePolicy]] = { - "MlpPolicy": SACPolicy, + policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = { + "MlpPolicy": SACMlpPolicy, "SACPolicy": SACPolicy, } + policy: SACPolicy + actor: Actor + critic: ContinuousCritic + critic_target: ContinuousCritic + def __init__( self, policy: Union[str, Type[SACPolicy]], @@ -139,11 +144,11 @@ class SAC(BetterOffPolicyAlgorithm): use_sde_at_warmup=use_sde_at_warmup, use_pca=use_pca, optimize_memory_usage=optimize_memory_usage, - supported_action_spaces=(spaces.Box), + supported_action_spaces=(spaces.Box,), support_multi_env=True, ) - print('[i] Using sbBrix version of SAC') + print('[i] Using metastable version of SAC') self.target_entropy = target_entropy self.log_ent_coef = None # type: Optional[th.Tensor] @@ -151,7 +156,7 @@ class SAC(BetterOffPolicyAlgorithm): # Inverse of the reward scale self.ent_coef = ent_coef self.target_update_interval = target_update_interval - self.ent_coef_optimizer = None + self.ent_coef_optimizer = Optional[th.optim.Adam] = None if _init_setup_model: self._setup_model() @@ -165,7 +170,7 @@ class SAC(BetterOffPolicyAlgorithm): # Target entropy is used when learning the entropy coefficient if self.target_entropy == "auto": # automatically set target entropy if needed - self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) + self.target_entropy = float(-np.prod(self.env.action_space.shape).astype(np.float32)) # type: ignore else: # Force conversion # this will also throw an error for unexpected string @@ -212,7 +217,7 @@ class SAC(BetterOffPolicyAlgorithm): for gradient_step in range(gradient_steps): # Sample replay buffer - replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) + replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr] # We need to sample because `log_std` may have changed between two gradient steps if self.use_sde: @@ -223,7 +228,7 @@ class SAC(BetterOffPolicyAlgorithm): log_prob = log_prob.reshape(-1, 1) ent_coef_loss = None - if self.ent_coef_optimizer is not None: + if self.ent_coef_optimizer is not None and self.log_ent_coef is not None: # Important: detach the variable from the graph # so we don't change it with other losses # see https://github.com/rail-berkeley/softlearning/issues/60 @@ -259,7 +264,8 @@ class SAC(BetterOffPolicyAlgorithm): # Compute critic loss critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values) - critic_losses.append(critic_loss.item()) + assert isinstance(critic_loss, th.Tensor) # for type checker + critic_losses.append(critic_loss.item()) # type: ignore[union-attr] # Optimize the critic self.critic.optimizer.zero_grad() diff --git a/metastable_baselines2/trpl/__init__.py b/metastable_baselines2/trpl/__init__.py new file mode 100644 index 0000000..bd8fa51 --- /dev/null +++ b/metastable_baselines2/trpl/__init__.py @@ -0,0 +1,2 @@ +from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy +from metastable_baselines2.trpl.trpl import TRPL diff --git a/metastable_baselines2/trpl/trpl.py b/metastable_baselines2/trpl/trpl.py new file mode 100644 index 0000000..fc80254 --- /dev/null +++ b/metastable_baselines2/trpl/trpl.py @@ -0,0 +1 @@ +pass \ No newline at end of file diff --git a/sbBrix/__init__.py b/sbBrix/__init__.py deleted file mode 100644 index d52b3da..0000000 --- a/sbBrix/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from sbBrix.ppo import PPO -from sbBrix.sac import SAC - -__all__ = [ - "PPO", - "SAC", -] diff --git a/sbBrix/ppo/__init__.py b/sbBrix/ppo/__init__.py deleted file mode 100644 index 86a1105..0000000 --- a/sbBrix/ppo/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from sbBrix.ppo.ppo import PPO diff --git a/sbBrix/sac/__init__.py b/sbBrix/sac/__init__.py deleted file mode 100644 index 59a45df..0000000 --- a/sbBrix/sac/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from sbBrix.sac.sac import SAC diff --git a/setup.py b/setup.py index 96c9730..3cf698a 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,12 @@ from setuptools import setup, find_packages setup( - name='sbBrix', - version='1.0.0', + name='metastable_baselines2', + version='2.1.0.0', # url='https://github.com/mypackage.git', # author='Author Name', # author_email='author@gmail.com', # description='Description of my package', packages=['.'], - install_requires=['gym', 'stable_baselines3==1.8.0'], + install_requires=['gymnasium', 'stable_baselines3==2.1.0'], )