Porting to sb3=2.1.0
parent 3e27ad3766
commit 1c1a909d27
metastable_baselines2/__init__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
+from sbBrix.ppo import PPO
+from sbBrix.sac import SAC
+
+try:
+    import priorConditionedAnnealing as pca
+except ModuleNotFoundError:
+    class pca():
+        def PCA_Distribution(*args, **kwargs):
+            raise Exception('PCA is not installed; cannot initialize PCA_Distribution.')
+
+__all__ = [
+    "PPO",
+    "SAC",
+]
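A note on the try/except in this new __init__.py: it is an optional-dependency fallback, so importing the package succeeds without priorConditionedAnnealing and only an actual use of PCA_Distribution fails. A minimal standalone sketch of the same pattern (names reused from the snippet above for illustration only):

# Sketch of the optional-dependency fallback used above (illustrative, not the package itself).
try:
    import priorConditionedAnnealing as pca  # real implementation, if installed
except ModuleNotFoundError:
    class pca:  # stub that postpones the failure until PCA is actually requested
        @staticmethod
        def PCA_Distribution(*args, **kwargs):
            raise ImportError("priorConditionedAnnealing is not installed; PCA_Distribution is unavailable.")

# Importing succeeds either way; calling the stub raises only when the real package is missing.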
@@ -1,5 +1,5 @@
 from stable_baselines3.common.distributions import *
-from priorConditionedAnnealing import PCA_Distribution
+from metastable_baselines2.pca import PCA_Distribution
 
 
 def _patched_make_proba_distribution(
@@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
 import numpy as np
 import torch as th
-from gym import spaces
+from gymnasium import spaces
 from torch import nn
 
 from stable_baselines3.common.distributions import (
@@ -34,11 +34,11 @@ from stable_baselines3.common.torch_layers import (
 
 from stable_baselines3.common.policies import ContinuousCritic
 
-from stable_baselines3.common.type_aliases import Schedule
+from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
 from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
 
 from .distributions import make_proba_distribution
-from priorConditionedAnnealing import PCA_Distribution
+from metastable_baselines2.pca import PCA_Distribution
 
 SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")
 
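PyTorchObs is the sb3 2.x alias for observations that may be a single tensor or a dict of named tensors; the method signatures below switch from th.Tensor to this alias. A rough sketch of what the alias covers (the exact definition lives in stable_baselines3.common.type_aliases and should be checked there):

# Assumed shape of the sb3 2.x observation alias; verify against stable_baselines3/common/type_aliases.py.
from typing import Dict, Union
import torch as th

TensorDict = Dict[str, th.Tensor]
PyTorchObs = Union[th.Tensor, TensorDict]

def describe(obs: PyTorchObs) -> str:
    # Dict observations (e.g. from Dict observation spaces) and flat tensors are both accepted.
    return "dict obs" if isinstance(obs, dict) else "tensor obs"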
@@ -773,10 +773,12 @@ class Actor(BasePolicy):
         dividing by 255.0 (True by default)
     """
 
+    action_space: spaces.Box
+
     def __init__(
         self,
         observation_space: spaces.Space,
-        action_space: spaces.Space,
+        action_space: spaces.Box,
         net_arch: List[int],
         features_extractor: nn.Module,
         features_dim: int,
@@ -894,7 +896,7 @@ class Actor(BasePolicy):
         else:
             self.action_dist.base_noise.reset()
 
-    def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
+    def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
         """
         Get the parameters for the action distribution.
 
@@ -915,17 +917,17 @@ class Actor(BasePolicy):
         log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
         return mean_actions, log_std, {}
 
-    def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
         mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
         # Note: the action is squashed
         return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
 
-    def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+    def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]:
         mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
         # return action and associated log prob
         return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
 
-    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
         return self(observation, deterministic)
 
 
@@ -958,10 +960,14 @@ class SACPolicy(BasePolicy):
         between the actor and the critic (this saves computation time)
     """
 
+    actor: Actor
+    critic: ContinuousCritic
+    critic_target: ContinuousCritic
+
     def __init__(
         self,
         observation_space: spaces.Space,
-        action_space: spaces.Space,
+        action_space: spaces.Box,
         lr_schedule: Schedule,
         net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
         activation_fn: Type[nn.Module] = nn.ReLU,
@@ -1049,7 +1055,7 @@ class SACPolicy(BasePolicy):
             # Create a separate features extractor for the critic
             # this requires more memory and computation
             self.critic = self.make_critic(features_extractor=None)
-            critic_parameters = self.critic.parameters()
+            critic_parameters = list(self.critic.parameters())
 
         # Critic target should not share the features extractor with critic
         self.critic_target = self.make_critic(features_extractor=None)
@@ -1100,10 +1106,10 @@ class SACPolicy(BasePolicy):
         critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
         return ContinuousCritic(**critic_kwargs).to(self.device)
 
-    def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
         return self._predict(obs, deterministic=deterministic)
 
-    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
         return self.actor(observation, deterministic)
 
     def set_training_mode(self, mode: bool) -> None:
@@ -1117,3 +1123,5 @@ class SACPolicy(BasePolicy):
         self.actor.set_training_mode(mode)
         self.critic.set_training_mode(mode)
         self.training = mode
+
+SACMlpPolicy = SACPolicy
metastable_baselines2/ppo/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+from metastable_baselines2.ppo.ppo import PPO
+
+from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
+
+__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "PPO"]
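With this __init__.py in place, the re-exported names can be imported from the package directly; a hedged usage sketch (the environment id is illustrative and not part of the commit):

# Assumes the package is installed under the name metastable_baselines2.
from metastable_baselines2.ppo import PPO, MlpPolicy

# model = PPO(MlpPolicy, "Pendulum-v1")  # illustrative; requires gymnasium and a registered environment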
@@ -1,14 +1,13 @@
 import warnings
-from typing import Any, Dict, Optional, Type, TypeVar, Union
+from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
 
 import numpy as np
 import torch as th
-from gym import spaces
+from gymnasium import spaces
 from torch.nn import functional as F
 
-# from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
+from stable_baselines3.common.buffers import RolloutBuffer
 from ..common.on_policy_algorithm import BetterOnPolicyAlgorithm
-# from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
 from ..common.policies import ActorCriticPolicy, BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import explained_variance, get_schedule_fn
@@ -54,6 +53,9 @@ class PPO(BetterOnPolicyAlgorithm):
         instead of action noise exploration (default: False)
     :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
         Default: -1 (only sample at the beginning of the rollout)
+    :param use_pca: Whether to use Prior Conditioned Annealing
+    :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
+    :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation
     :param target_kl: Limit the KL divergence between updates,
         because the clipping is not enough to prevent large update
         see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
@@ -70,7 +72,7 @@ class PPO(BetterOnPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
-    policy_aliases: Dict[str, Type[BasePolicy]] = {
+    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
         "MlpPolicy": ActorCriticPolicy
     }
 
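The ClassVar annotation marks policy_aliases as a class-level lookup table rather than an instance attribute, which is how sb3 2.x declares it and what strict type checkers expect for mutable class attributes. A small standalone illustration (names are made up for the example):

from typing import ClassVar, Dict, Type


class BaseThing:
    pass


class Algo:
    # Shared, class-level mapping; ClassVar tells type checkers it is not a per-instance field.
    aliases: ClassVar[Dict[str, Type[BaseThing]]] = {"default": BaseThing}

    def resolve(self, name: str) -> Type[BaseThing]:
        return self.aliases[name]


print(Algo().resolve("default"))  # <class '__main__.BaseThing'>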
@@ -93,6 +95,8 @@ class PPO(BetterOnPolicyAlgorithm):
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         use_pca: bool = False,
+        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
+        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
         target_kl: Optional[float] = None,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
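The two new constructor arguments let callers supply a custom rollout buffer, mirroring the sb3 2.1 PPO signature; they are simply forwarded to the parent class in the next hunk. A hedged usage sketch (import path and environment id are assumptions, not taken from the repository's docs):

# Illustrative call; assumes metastable_baselines2 is installed and exposes PPO as shown above.
from stable_baselines3.common.buffers import RolloutBuffer
from metastable_baselines2.ppo import PPO

model = PPO(
    "MlpPolicy",
    "Pendulum-v1",                       # illustrative environment id
    rollout_buffer_class=RolloutBuffer,  # explicit default; a custom subclass could go here
    rollout_buffer_kwargs=None,
)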
@@ -115,6 +119,8 @@ class PPO(BetterOnPolicyAlgorithm):
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,
             use_pca=use_pca,
+            rollout_buffer_class=rollout_buffer_class,
+            rollout_buffer_kwargs=rollout_buffer_kwargs,
             stats_window_size=stats_window_size,
             tensorboard_log=tensorboard_log,
             policy_kwargs=policy_kwargs,
@@ -130,7 +136,7 @@ class PPO(BetterOnPolicyAlgorithm):
             ),
         )
 
-        print('[i] Using sbBrix version of PPO')
+        print('[i] Using metastable version of PPO')
 
         # Sanity check, otherwise it will lead to noisy gradient and NaN
         # because of the advantage normalization
metastable_baselines2/sac/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+from metastable_baselines2.sac.sac import SAC
@@ -1,18 +1,18 @@
-from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
 import numpy as np
 import torch as th
-from gym import spaces
+from gymnasium import spaces
 from torch.nn import functional as F
 
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.noise import ActionNoise
 # from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from ..common.off_policy_algorithm import BetterOffPolicyAlgorithm
-from stable_baselines3.common.policies import BasePolicy
+from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
-from ..common.policies import SACPolicy
+from ..common.policies import SACMlpPolicy, SACPolicy, Actor  # , CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
 
 SelfSAC = TypeVar("SelfSAC", bound="SAC")
 
@@ -78,11 +78,16 @@ class SAC(BetterOffPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
-    policy_aliases: Dict[str, Type[BasePolicy]] = {
-        "MlpPolicy": SACPolicy,
+    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
+        "MlpPolicy": SACMlpPolicy,
         "SACPolicy": SACPolicy,
     }
 
+    policy: SACPolicy
+    actor: Actor
+    critic: ContinuousCritic
+    critic_target: ContinuousCritic
+
     def __init__(
         self,
         policy: Union[str, Type[SACPolicy]],
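With the updated aliases, the string "MlpPolicy" now resolves to SACMlpPolicy, which the policies hunk earlier defines as an alias of SACPolicy. A hedged usage sketch (environment id is illustrative):

# Assumes the sac/__init__.py re-export shown above.
from metastable_baselines2.sac import SAC

model = SAC("MlpPolicy", "Pendulum-v1")  # resolved through policy_aliases to SACMlpPolicy (= SACPolicy)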
@@ -139,11 +144,11 @@ class SAC(BetterOffPolicyAlgorithm):
             use_sde_at_warmup=use_sde_at_warmup,
             use_pca=use_pca,
             optimize_memory_usage=optimize_memory_usage,
-            supported_action_spaces=(spaces.Box),
+            supported_action_spaces=(spaces.Box,),
             support_multi_env=True,
         )
 
-        print('[i] Using sbBrix version of SAC')
+        print('[i] Using metastable version of SAC')
 
         self.target_entropy = target_entropy
         self.log_ent_coef = None  # type: Optional[th.Tensor]
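The supported_action_spaces fix is a plain Python tuple issue: (spaces.Box) is just spaces.Box wrapped in parentheses, while (spaces.Box,) is a one-element tuple, matching the tuple of space types the base class declares. A standalone illustration:

from gymnasium import spaces

print(type((spaces.Box)))   # <class 'type'> - parentheses alone do not create a tuple
print(type((spaces.Box,)))  # <class 'tuple'> - the trailing comma does

box = spaces.Box(low=-1.0, high=1.0, shape=(3,))
print(isinstance(box, (spaces.Box,)))  # True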
@@ -151,7 +156,7 @@ class SAC(BetterOffPolicyAlgorithm):
         # Inverse of the reward scale
         self.ent_coef = ent_coef
         self.target_update_interval = target_update_interval
-        self.ent_coef_optimizer = None
+        self.ent_coef_optimizer: Optional[th.optim.Adam] = None
 
         if _init_setup_model:
             self._setup_model()
@@ -165,7 +170,7 @@ class SAC(BetterOffPolicyAlgorithm):
         # Target entropy is used when learning the entropy coefficient
         if self.target_entropy == "auto":
             # automatically set target entropy if needed
-            self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)
+            self.target_entropy = float(-np.prod(self.env.action_space.shape).astype(np.float32))  # type: ignore
         else:
             # Force conversion
             # this will also throw an error for unexpected string
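The "auto" branch sets the target entropy to the negative of the action dimensionality, now wrapped in float() so it is a plain Python scalar. For a 3-dimensional Box action space, for example:

import numpy as np
from gymnasium import spaces

action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))
target_entropy = float(-np.prod(action_space.shape).astype(np.float32))
print(target_entropy)  # -3.0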
@@ -212,7 +217,7 @@ class SAC(BetterOffPolicyAlgorithm):
 
         for gradient_step in range(gradient_steps):
             # Sample replay buffer
-            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
+            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
 
             # We need to sample because `log_std` may have changed between two gradient steps
             if self.use_sde:
@@ -223,7 +228,7 @@ class SAC(BetterOffPolicyAlgorithm):
                 log_prob = log_prob.reshape(-1, 1)
 
             ent_coef_loss = None
-            if self.ent_coef_optimizer is not None:
+            if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
                 # Important: detach the variable from the graph
                 # so we don't change it with other losses
                 # see https://github.com/rail-berkeley/softlearning/issues/60
@@ -259,7 +264,8 @@ class SAC(BetterOffPolicyAlgorithm):
 
             # Compute critic loss
             critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
-            critic_losses.append(critic_loss.item())
+            assert isinstance(critic_loss, th.Tensor)  # for type checker
+            critic_losses.append(critic_loss.item())  # type: ignore[union-attr]
 
             # Optimize the critic
             self.critic.optimizer.zero_grad()
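The added assert is there to narrow the type: sum() over a generator is annotated as possibly returning its integer start value (0 for an empty sequence), so the isinstance check convinces the type checker that .item() is safe. A minimal standalone version of the same pattern:

import torch as th

current_q_values = (th.tensor(1.0), th.tensor(2.0))
target_q = th.tensor(1.5)

loss = 0.5 * sum((q - target_q) ** 2 for q in current_q_values)  # statically: could be int or Tensor
assert isinstance(loss, th.Tensor)  # narrows the type before calling .item()
print(loss.item())  # 0.25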
metastable_baselines2/trpl/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
+from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
+from metastable_baselines2.trpl.trpl import TRPL
metastable_baselines2/trpl/trpl.py (new file, 1 line)
@@ -0,0 +1 @@
+pass
@@ -1,7 +0,0 @@
-from sbBrix.ppo import PPO
-from sbBrix.sac import SAC
-
-__all__ = [
-    "PPO",
-    "SAC",
-]
@@ -1 +0,0 @@
-from sbBrix.ppo.ppo import PPO
@@ -1 +0,0 @@
-from sbBrix.sac.sac import SAC
setup.py (6 lines changed)
@@ -1,12 +1,12 @@
 from setuptools import setup, find_packages
 
 setup(
-    name='sbBrix',
-    version='1.0.0',
+    name='metastable_baselines2',
+    version='2.1.0.0',
     # url='https://github.com/mypackage.git',
     # author='Author Name',
     # author_email='author@gmail.com',
     # description='Description of my package',
     packages=['.'],
-    install_requires=['gym', 'stable_baselines3==1.8.0'],
+    install_requires=['gymnasium', 'stable_baselines3==2.1.0'],
 )
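After installing the renamed package, the bumped pins can be sanity-checked from Python (a hedged check; the expected version string comes from the pin above):

# Quick import check after installation (e.g. via an editable install of this setup.py).
import gymnasium
import stable_baselines3

print(stable_baselines3.__version__)  # expected "2.1.0", matching the pin in install_requires
print(gymnasium.__version__)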