Porting to sb3=2.1.0
This commit is contained in: parent 3e27ad3766, commit 1c1a909d27
metastable_baselines2/__init__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
+from metastable_baselines2.ppo import PPO
+from metastable_baselines2.sac import SAC
+
+try:
+    import priorConditionedAnnealing as pca
+except ModuleNotFoundError:
+    class pca():
+        def PCA_Distribution(*args, **kwargs):
+            raise Exception('PCA is not installed; cannot initialize PCA_Distribution.')
+
+__all__ = [
+    "PPO",
+    "SAC",
+]
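Editor's note: a quick usage sketch of the new top-level package and its PCA fallback (not part of the commit; it assumes the package is installed without the optional priorConditionedAnnealing dependency).

# Hypothetical session illustrating the fallback above when the optional
# priorConditionedAnnealing package is not installed.
from metastable_baselines2 import PPO, SAC   # succeeds either way
from metastable_baselines2 import pca        # stub class when the extra is missing

try:
    pca.PCA_Distribution()
except Exception as err:
    print(err)  # PCA is not installed; cannot initialize PCA_Distribution.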
@@ -1,5 +1,5 @@
 from stable_baselines3.common.distributions import *
-from priorConditionedAnnealing import PCA_Distribution
+from metastable_baselines2.pca import PCA_Distribution


 def _patched_make_proba_distribution(
@@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

 import numpy as np
 import torch as th
-from gym import spaces
+from gymnasium import spaces
 from torch import nn

 from stable_baselines3.common.distributions import (
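Editor's note: stable-baselines3 2.x is built on Gymnasium, so space classes must now come from gymnasium rather than the legacy gym package. A minimal illustration (not part of the commit; the Box parameters are arbitrary):

import numpy as np
from gymnasium import spaces  # was: from gym import spaces

action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
print(action_space.sample().shape)  # (3,)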
@@ -34,11 +34,11 @@ from stable_baselines3.common.torch_layers import (

 from stable_baselines3.common.policies import ContinuousCritic

-from stable_baselines3.common.type_aliases import Schedule
+from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
 from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor

 from .distributions import make_proba_distribution
-from priorConditionedAnnealing import PCA_Distribution
+from metastable_baselines2.pca import PCA_Distribution

 SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")

@@ -773,10 +773,12 @@ class Actor(BasePolicy):
         dividing by 255.0 (True by default)
     """

+    action_space: spaces.Box
+
     def __init__(
         self,
         observation_space: spaces.Space,
-        action_space: spaces.Space,
+        action_space: spaces.Box,
         net_arch: List[int],
         features_extractor: nn.Module,
         features_dim: int,
@@ -894,7 +896,7 @@ class Actor(BasePolicy):
         else:
             self.action_dist.base_noise.reset()

-    def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
+    def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
         """
         Get the parameters for the action distribution.

@@ -915,17 +917,17 @@
         log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
         return mean_actions, log_std, {}

-    def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
         mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
         # Note: the action is squashed
         return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)

-    def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+    def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]:
         mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
         # return action and associated log prob
         return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)

-    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
         return self(observation, deterministic)


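Editor's note: the observation annotations above move from th.Tensor to PyTorchObs, the type alias SB3 2.x exposes in stable_baselines3.common.type_aliases. To my understanding it covers both plain tensors and dict observations, roughly as sketched below (simplified; the exact definition in SB3 may differ).

from typing import Dict, Union
import torch as th

PyTorchObs = Union[th.Tensor, Dict[str, th.Tensor]]  # approximation of the SB3 alias

def forward(obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
    ...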
@@ -958,10 +960,14 @@ class SACPolicy(BasePolicy):
         between the actor and the critic (this saves computation time)
     """

+    actor: Actor
+    critic: ContinuousCritic
+    critic_target: ContinuousCritic
+
     def __init__(
         self,
         observation_space: spaces.Space,
-        action_space: spaces.Space,
+        action_space: spaces.Box,
         lr_schedule: Schedule,
         net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
         activation_fn: Type[nn.Module] = nn.ReLU,
@@ -1049,7 +1055,7 @@ class SACPolicy(BasePolicy):
             # Create a separate features extractor for the critic
             # this requires more memory and computation
             self.critic = self.make_critic(features_extractor=None)
-            critic_parameters = self.critic.parameters()
+            critic_parameters = list(self.critic.parameters())

         # Critic target should not share the features extractor with critic
         self.critic_target = self.make_critic(features_extractor=None)
@@ -1100,10 +1106,10 @@ class SACPolicy(BasePolicy):
         critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
         return ContinuousCritic(**critic_kwargs).to(self.device)

-    def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
         return self._predict(obs, deterministic=deterministic)

-    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
         return self.actor(observation, deterministic)

     def set_training_mode(self, mode: bool) -> None:
@@ -1117,3 +1123,5 @@ class SACPolicy(BasePolicy):
         self.actor.set_training_mode(mode)
         self.critic.set_training_mode(mode)
         self.training = mode
+
+SACMlpPolicy = SACPolicy
metastable_baselines2/ppo/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+from metastable_baselines2.ppo.ppo import PPO
+
+from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
+
+__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "PPO"]
@@ -1,14 +1,13 @@
 import warnings
-from typing import Any, Dict, Optional, Type, TypeVar, Union
+from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union

 import numpy as np
 import torch as th
-from gym import spaces
+from gymnasium import spaces
 from torch.nn import functional as F

-# from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
+from stable_baselines3.common.buffers import RolloutBuffer
 from ..common.on_policy_algorithm import BetterOnPolicyAlgorithm
-# from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
 from ..common.policies import ActorCriticPolicy, BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import explained_variance, get_schedule_fn
@@ -54,6 +53,9 @@ class PPO(BetterOnPolicyAlgorithm):
         instead of action noise exploration (default: False)
     :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
         Default: -1 (only sample at the beginning of the rollout)
+    :param use_pca: Whether to use Prior Conditioned Annealing
+    :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
+    :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation
     :param target_kl: Limit the KL divergence between updates,
         because the clipping is not enough to prevent large update
         see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
@@ -70,7 +72,7 @@ class PPO(BetterOnPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """

-    policy_aliases: Dict[str, Type[BasePolicy]] = {
+    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
         "MlpPolicy": ActorCriticPolicy
     }

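Editor's note: annotating policy_aliases with ClassVar follows the SB3 2.x convention and tells type checkers that the mapping belongs to the class rather than to instances. A small standalone illustration (not from the commit):

from typing import ClassVar, Dict

class Algo:
    # shared by all instances; mypy flags attempts to rebind it through `self`
    policy_aliases: ClassVar[Dict[str, str]] = {"MlpPolicy": "ActorCriticPolicy"}

print(Algo.policy_aliases["MlpPolicy"])  # ActorCriticPolicy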
@@ -93,6 +95,8 @@ class PPO(BetterOnPolicyAlgorithm):
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         use_pca: bool = False,
+        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
+        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
         target_kl: Optional[float] = None,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
@@ -115,6 +119,8 @@ class PPO(BetterOnPolicyAlgorithm):
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,
             use_pca=use_pca,
+            rollout_buffer_class=rollout_buffer_class,
+            rollout_buffer_kwargs=rollout_buffer_kwargs,
             stats_window_size=stats_window_size,
             tensorboard_log=tensorboard_log,
             policy_kwargs=policy_kwargs,
@@ -130,7 +136,7 @@ class PPO(BetterOnPolicyAlgorithm):
             ),
         )

-        print('[i] Using sbBrix version of PPO')
+        print('[i] Using metastable version of PPO')

         # Sanity check, otherwise it will lead to noisy gradient and NaN
         # because of the advantage normalization
metastable_baselines2/sac/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+from metastable_baselines2.sac.sac import SAC
@@ -1,18 +1,18 @@
-from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union

 import numpy as np
 import torch as th
-from gym import spaces
+from gymnasium import spaces
 from torch.nn import functional as F

 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.noise import ActionNoise
 # from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from ..common.off_policy_algorithm import BetterOffPolicyAlgorithm
-from stable_baselines3.common.policies import BasePolicy
+from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
-from ..common.policies import SACPolicy
+from ..common.policies import SACMlpPolicy, SACPolicy, Actor  #, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy

 SelfSAC = TypeVar("SelfSAC", bound="SAC")

@@ -78,11 +78,16 @@ class SAC(BetterOffPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """

-    policy_aliases: Dict[str, Type[BasePolicy]] = {
-        "MlpPolicy": SACPolicy,
+    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
+        "MlpPolicy": SACMlpPolicy,
+        "SACPolicy": SACPolicy,
     }

+    policy: SACPolicy
+    actor: Actor
+    critic: ContinuousCritic
+    critic_target: ContinuousCritic

     def __init__(
         self,
         policy: Union[str, Type[SACPolicy]],
@@ -139,11 +144,11 @@
             use_sde_at_warmup=use_sde_at_warmup,
             use_pca=use_pca,
             optimize_memory_usage=optimize_memory_usage,
-            supported_action_spaces=(spaces.Box),
+            supported_action_spaces=(spaces.Box,),
             support_multi_env=True,
         )

-        print('[i] Using sbBrix version of SAC')
+        print('[i] Using metastable version of SAC')

         self.target_entropy = target_entropy
         self.log_ent_coef = None  # type: Optional[th.Tensor]
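Editor's note: the supported_action_spaces change above fixes a subtle Python pitfall: (spaces.Box) is just a parenthesised expression, while the trailing comma makes it a one-element tuple, which is what isinstance-style checks expect. For example:

from gymnasium import spaces

assert (spaces.Box) is spaces.Box          # parentheses alone do not build a tuple
assert isinstance((spaces.Box,), tuple)    # the trailing comma does
assert isinstance(spaces.Box(-1, 1, (1,)), (spaces.Box,))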
@@ -151,7 +156,7 @@
         # Inverse of the reward scale
         self.ent_coef = ent_coef
         self.target_update_interval = target_update_interval
-        self.ent_coef_optimizer = None
+        self.ent_coef_optimizer: Optional[th.optim.Adam] = None

         if _init_setup_model:
             self._setup_model()
@@ -165,7 +170,7 @@
         # Target entropy is used when learning the entropy coefficient
         if self.target_entropy == "auto":
             # automatically set target entropy if needed
-            self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)
+            self.target_entropy = float(-np.prod(self.env.action_space.shape).astype(np.float32))  # type: ignore
         else:
             # Force conversion
             # this will also throw an error for unexpected string
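Editor's note: the wrapping float(...) above keeps self.target_entropy a plain Python float rather than a 0-d numpy value. A worked example with a hypothetical 3-dimensional action space (not from the commit):

import numpy as np

action_shape = (3,)  # hypothetical Box action space shape
target_entropy = float(-np.prod(action_shape).astype(np.float32))
print(target_entropy, type(target_entropy))  # -3.0 <class 'float'>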
@@ -212,7 +217,7 @@

         for gradient_step in range(gradient_steps):
             # Sample replay buffer
-            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
+            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]

             # We need to sample because `log_std` may have changed between two gradient steps
             if self.use_sde:
@@ -223,7 +228,7 @@
             log_prob = log_prob.reshape(-1, 1)

             ent_coef_loss = None
-            if self.ent_coef_optimizer is not None:
+            if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
                 # Important: detach the variable from the graph
                 # so we don't change it with other losses
                 # see https://github.com/rail-berkeley/softlearning/issues/60
@@ -259,7 +264,8 @@

             # Compute critic loss
             critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
-            critic_losses.append(critic_loss.item())
+            assert isinstance(critic_loss, th.Tensor)  # for type checker
+            critic_losses.append(critic_loss.item())  # type: ignore[union-attr]

             # Optimize the critic
             self.critic.optimizer.zero_grad()
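Editor's note: the added assert documents that summing F.mse_loss over the ensemble of Q-networks really yields a tensor (sum starts from the int 0, so the static type is ambiguous). A standalone sketch with random tensors standing in for the critic outputs:

import torch as th
from torch.nn import functional as F

current_q_values = (th.rand(4, 1), th.rand(4, 1))   # stand-ins for the twin critics
target_q_values = th.rand(4, 1)

critic_loss = 0.5 * sum(F.mse_loss(q, target_q_values) for q in current_q_values)
assert isinstance(critic_loss, th.Tensor)
print(critic_loss.item())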
metastable_baselines2/trpl/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
+from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
+from metastable_baselines2.trpl.trpl import TRPL
metastable_baselines2/trpl/trpl.py (new file, 1 line)
@@ -0,0 +1 @@
+pass
@@ -1,7 +0,0 @@
-from sbBrix.ppo import PPO
-from sbBrix.sac import SAC
-
-__all__ = [
-    "PPO",
-    "SAC",
-]
@@ -1 +0,0 @@
-from sbBrix.ppo.ppo import PPO
@@ -1 +0,0 @@
-from sbBrix.sac.sac import SAC
setup.py (6 changed lines)
@@ -1,12 +1,12 @@
 from setuptools import setup, find_packages

 setup(
-    name='sbBrix',
-    version='1.0.0',
+    name='metastable_baselines2',
+    version='2.1.0.0',
     # url='https://github.com/mypackage.git',
     # author='Author Name',
     # author_email='author@gmail.com',
     # description='Description of my package',
     packages=['.'],
-    install_requires=['gym', 'stable_baselines3==1.8.0'],
+    install_requires=['gymnasium', 'stable_baselines3==2.1.0'],
 )
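Editor's note: a quick import check after installing with the updated requirements (a sketch; the exact versions depend on the environment, but stable_baselines3 is pinned to 2.1.0 above).

import gymnasium
import stable_baselines3

print(stable_baselines3.__version__)  # expected: 2.1.0 given the pin in setup.py
print(gymnasium.__version__)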