diff --git a/metastable_baselines2/__init__.py b/metastable_baselines2/__init__.py
index 19bfa76..07d11b3 100644
--- a/metastable_baselines2/__init__.py
+++ b/metastable_baselines2/__init__.py
@@ -1,12 +1,5 @@
-from sbBrix.ppo import PPO
-from sbBrix.sac import SAC
-
-try:
-    import priorConditionedAnnealing as pca
-except ModuleNotFoundError:
-    class pca():
-        def PCA_Distribution(*args, **kwargs):
-            raise Exception('PCA is not installed; cannot initialize PCA_Distribution.')
+from metastable_baselines2.ppo import PPO
+from metastable_baselines2.sac import SAC
 
 __all__ = [
     "PPO",
diff --git a/metastable_baselines2/common/distributions.py b/metastable_baselines2/common/distributions.py
index 4daa98e..735ba2f 100644
--- a/metastable_baselines2/common/distributions.py
+++ b/metastable_baselines2/common/distributions.py
@@ -1,6 +1,5 @@
 from stable_baselines3.common.distributions import *
-from metastable_baselines2.pca import PCA_Distribution
-
+from metastable_baselines2.common.pca import PCA_Distribution
 
 def _patched_make_proba_distribution(
     action_space: spaces.Space, use_sde: bool = False, use_pca: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
diff --git a/metastable_baselines2/common/off_policy_algorithm.py b/metastable_baselines2/common/off_policy_algorithm.py
index 0720e33..b4df453 100644
--- a/metastable_baselines2/common/off_policy_algorithm.py
+++ b/metastable_baselines2/common/off_policy_algorithm.py
@@ -1,5 +1,29 @@
-from stable_baselines3.common.off_policy_algorithm import *
+import io
+import pathlib
+import sys
+import time
+import warnings
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
+import numpy as np
+import torch as th
+from gymnasium import spaces
+
+from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
+from stable_baselines3.common.callbacks import BaseCallback
+from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
+from stable_baselines3.common.policies import BasePolicy
+from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
+from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit
+from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
+from stable_baselines3.common.vec_env import VecEnv
+from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
+
+from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
+
+SelfOffPolicyAlgorithm = TypeVar("SelfOffPolicyAlgorithm", bound="BetterOffPolicyAlgorithm")
 
 
 class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
     """
@@ -52,6 +76,8 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
     :param supported_action_spaces: The action spaces supported by the algorithm.
     """
 
+    actor: th.nn.Module
+
     def __init__(
         self,
         policy: Union[str, Type[BasePolicy]],
@@ -68,7 +94,7 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
         replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[Dict[str, Any]] = {},
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
         verbose: int = 0,
@@ -217,6 +243,9 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
         if truncate_last_traj:
             self.replay_buffer.truncate_last_trajectory()
 
+        # Update saved replay buffer device to match current setting, see GH#1561
+        self.replay_buffer.device = self.device
+
     def _setup_learn(
         self,
         total_timesteps: int,
@@ -248,10 +277,21 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
                 "You should use `reset_num_timesteps=False` or `optimize_memory_usage=False`"
                 "to avoid that issue."
             )
+            assert replay_buffer is not None  # for mypy
             # Go to the previous index
             pos = (replay_buffer.pos - 1) % replay_buffer.buffer_size
             replay_buffer.dones[pos] = True
 
+        assert self.env is not None, "You must set the environment before calling _setup_learn()"
+
+        # Vectorize action noise if needed
+        if (
+            self.action_noise is not None
+            and self.env.num_envs > 1
+            and not isinstance(self.action_noise, VectorizedActionNoise)
+        ):
+            self.action_noise = VectorizedActionNoise(self.action_noise, self.env.num_envs)
+
         return super()._setup_learn(
             total_timesteps,
             callback,
@@ -279,6 +319,9 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
 
         callback.on_training_start(locals(), globals())
 
+        assert self.env is not None, "You must set the environment before calling learn()"
+        assert isinstance(self.train_freq, TrainFreq)  # check done in _setup_learn()
+
         while self.num_timesteps < total_timesteps:
             rollout = self.collect_rollouts(
                 self.env,
@@ -290,7 +333,7 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
                 log_interval=log_interval,
             )
 
-            if rollout.continue_training is False:
+            if not rollout.continue_training:
                 break
 
             if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
@@ -341,6 +384,7 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
             # Note: when using continuous actions,
             # we assume that the policy uses tanh to scale the action
             # We use non-deterministic action in the case of SAC, for TD3, it does not matter
+            assert self._last_obs is not None, "self._last_obs was not set"
             unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
 
         # Rescale the action from [low, high] to [-1, 1]
@@ -364,6 +408,9 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
         """
         Write log.
         """
+        assert self.ep_info_buffer is not None
+        assert self.ep_success_buffer is not None
+
         time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
         fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
         self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
@@ -520,7 +567,7 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
             # Give access to local variables
             callback.update_locals(locals())
             # Only stop training if return value is False, not when it is None.
-            if callback.on_step() is False:
+            if not callback.on_step():
                 return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training=False)
 
             # Retrieve reward and episode length if using Monitor wrapper
diff --git a/metastable_baselines2/common/on_policy_algorithm.py b/metastable_baselines2/common/on_policy_algorithm.py
index a6d6e88..c2f9317 100644
--- a/metastable_baselines2/common/on_policy_algorithm.py
+++ b/metastable_baselines2/common/on_policy_algorithm.py
@@ -1,5 +1,22 @@
-from stable_baselines3.common.on_policy_algorithm import *
+import sys
+import time
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
+import numpy as np
+import torch as th
+from gymnasium import spaces
+
+from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
+from stable_baselines3.common.callbacks import BaseCallback
+from stable_baselines3.common.policies import ActorCriticPolicy
+from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
+from stable_baselines3.common.utils import obs_as_tensor, safe_mean
+from stable_baselines3.common.vec_env import VecEnv
+
+from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
+
+SelfOnPolicyAlgorithm = TypeVar("SelfOnPolicyAlgorithm", bound="BetterOnPolicyAlgorithm")
 
 
 class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
     """
@@ -21,6 +38,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
         instead of action noise exploration (default: False)
     :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
         Default: -1 (only sample at the beginning of the rollout)
+    :param use_pca: Whether to use PCA.
+    :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
+    :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation.
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over
     :param tensorboard_log: the log location for tensorboard (if None, no logging)
@@ -50,6 +70,8 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
         use_sde: bool,
         sde_sample_freq: int,
         use_pca: bool,
+        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
+        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
         monitor_wrapper: bool = True,
@@ -62,6 +84,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
     ):
         assert not (use_sde and use_pca)
         self.use_pca = use_pca
+
+        assert not rollout_buffer_class and not rollout_buffer_kwargs
+
         super().__init__(
             policy=policy,
             env=env,
@@ -77,6 +102,8 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
             device=device,
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,
+            #rollout_buffer_class = rollout_buffer_class,
+            #rollout_buffer_kwargs = rollout_buffer_kwargs,
             # support_multi_env=True,
             seed=seed,
             stats_window_size=stats_window_size,
@@ -86,13 +113,20 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
             _init_setup_model=_init_setup_model
         )
 
+        self.rollout_buffer_class = None
+        self.rollout_buffer_kwargs = {}
+
     def _setup_model(self) -> None:
         self._setup_lr_schedule()
         self.set_random_seed(self.seed)
 
-        buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RolloutBuffer
+        if self.rollout_buffer_class is None:
+            if isinstance(self.observation_space, spaces.Dict):
+                self.rollout_buffer_class = DictRolloutBuffer
+            else:
+                self.rollout_buffer_class = RolloutBuffer
 
-        self.rollout_buffer = buffer_cls(
+        self.rollout_buffer = self.rollout_buffer_class(
             self.n_steps,
             self.observation_space,
             self.action_space,
@@ -100,6 +134,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
             gamma=self.gamma,
             gae_lambda=self.gae_lambda,
             n_envs=self.n_envs,
+            **self.rollout_buffer_kwargs,
         )
         self.policy = self.policy_class(  # pytype:disable=not-instantiable
             self.observation_space,
@@ -158,7 +193,14 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
             clipped_actions = actions
             # Clip the actions to avoid out of bound error
             if isinstance(self.action_space, spaces.Box):
-                clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
+                if self.policy.squash_output:
+                    # Unscale the actions to match env bounds
+                    # if they were previously squashed (scaled in [-1, 1])
+                    clipped_actions = self.policy.unscale_action(clipped_actions)
+                else:
+                    # Otherwise, clip the actions to avoid out of bound error
+                    # as we are sampling from an unbounded Gaussian distribution
+                    clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
 
             new_obs, rewards, dones, infos = env.step(clipped_actions)
 
@@ -166,7 +208,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
 
             # Give access to local variables
             callback.update_locals(locals())
-            if callback.on_step() is False:
+            if not callback.on_step():
                 return False
 
             self._update_info_buffer(infos)
@@ -199,6 +241,8 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
 
         rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
 
+        callback.update_locals(locals())
+
         callback.on_rollout_end()
 
         return True
@@ -234,7 +278,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
         while self.num_timesteps < total_timesteps:
            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
 
-            if continue_training is False:
+            if not continue_training:
                 break
 
             iteration += 1
@@ -242,6 +286,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
 
             # Display training infos
             if log_interval is not None and iteration % log_interval == 0:
+                assert self.ep_info_buffer is not None
                 time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
                 fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
                 self.logger.record("time/iterations", iteration, exclude="tensorboard")
diff --git a/metastable_baselines2/common/pca.py b/metastable_baselines2/common/pca.py
new file mode 100644
index 0000000..fde6fa0
--- /dev/null
+++ b/metastable_baselines2/common/pca.py
@@ -0,0 +1,5 @@
+try:
+    from priorConditionedAnnealing import PCA_Distribution
+except ModuleNotFoundError:
+    def PCA_Distribution(*args, **kwargs):
+        raise Exception('PCA is not installed; cannot initialize PCA_Distribution.')
\ No newline at end of file
diff --git a/metastable_baselines2/common/policies.py b/metastable_baselines2/common/policies.py
index 0867f18..e46ded6 100644
--- a/metastable_baselines2/common/policies.py
+++ b/metastable_baselines2/common/policies.py
@@ -34,11 +34,12 @@ from stable_baselines3.common.torch_layers import (
 
 from stable_baselines3.common.policies import ContinuousCritic
 
-from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
+from stable_baselines3.common.type_aliases import Schedule
 from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
 
 from .distributions import make_proba_distribution
-from metastable_baselines2.pca import PCA_Distribution
+from metastable_baselines2.common.pca import PCA_Distribution
+
 
 
 SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")
@@ -896,7 +897,7 @@ class Actor(BasePolicy):
         else:
             self.action_dist.base_noise.reset()
 
-    def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
+    def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
         """
         Get the parameters for the action distribution.
 
@@ -917,17 +918,17 @@ class Actor(BasePolicy):
         log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
         return mean_actions, log_std, {}
 
-    def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
+    def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
         mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
         # Note: the action is squashed
         return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
 
-    def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]:
+    def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
         mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
         # return action and associated log prob
         return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
 
-    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
+    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
         return self(observation, deterministic)
 
 
@@ -1106,10 +1107,10 @@ class SACPolicy(BasePolicy):
         critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
         return ContinuousCritic(**critic_kwargs).to(self.device)
 
-    def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
+    def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
         return self._predict(obs, deterministic=deterministic)
 
-    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
+    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
         return self.actor(observation, deterministic)
 
     def set_training_mode(self, mode: bool) -> None:
diff --git a/metastable_baselines2/sac/sac.py b/metastable_baselines2/sac/sac.py
index 4db8c62..d51da05 100644
--- a/metastable_baselines2/sac/sac.py
+++ b/metastable_baselines2/sac/sac.py
@@ -113,7 +113,7 @@ class SAC(BetterOffPolicyAlgorithm):
         use_pca: bool = False,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[Dict[str, Any]] = {},
         verbose: int = 0,
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
@@ -156,7 +156,7 @@ class SAC(BetterOffPolicyAlgorithm):
         # Inverse of the reward scale
         self.ent_coef = ent_coef
         self.target_update_interval = target_update_interval
-        self.ent_coef_optimizer = Optional[th.optim.Adam] = None
+        self.ent_coef_optimizer = None
 
         if _init_setup_model:
             self._setup_model()
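
Usage note (not part of the patch): the sketch below is a minimal smoke test of the entry points this diff touches. It assumes the fork keeps stable-baselines3's constructor and learn() API, that gymnasium and the Pendulum-v1 environment are available, and that "MlpPolicy" exists as in upstream SB3; passing use_pca=True would additionally require the optional priorConditionedAnnealing package, since the fallback stub in metastable_baselines2/common/pca.py raises when called.

# Sketch only: MlpPolicy and Pendulum-v1 are assumed SB3/gymnasium conventions,
# not names confirmed by this diff.
import gymnasium as gym

from metastable_baselines2 import PPO, SAC

env = gym.make("Pendulum-v1")

# Off-policy path (BetterOffPolicyAlgorithm); SAC exposes use_pca in this diff.
sac_model = SAC("MlpPolicy", env, use_pca=False, verbose=0)
sac_model.learn(total_timesteps=1_000)

# On-policy path (BetterOnPolicyAlgorithm, reached via PPO).
ppo_model = PPO("MlpPolicy", env, verbose=0)
ppo_model.learn(total_timesteps=1_000)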