Updating to new sb3 version

This commit is contained in:
Dominik Moritz Roth 2023-11-19 18:34:15 +01:00
parent 1c1a909d27
commit 865efe4221
7 changed files with 121 additions and 31 deletions

View File

@ -1,12 +1,5 @@
from sbBrix.ppo import PPO
from sbBrix.sac import SAC
try:
import priorConditionedAnnealing as pca
except ModuleNotFoundError:
class pca():
def PCA_Distribution(*args, **kwargs):
raise Exception('PCA is not installed; cannot initialize PCA_Distribution.')
from metastable_baselines2.ppo import PPO
from metastable_baselines2.sac import SAC
__all__ = [
"PPO",

View File

@ -1,6 +1,5 @@
from stable_baselines3.common.distributions import *
from metastable_baselines2.pca import PCA_Distribution
from metastable_baselines2.common.pca import PCA_Distribution
def _patched_make_proba_distribution(
action_space: spaces.Space, use_sde: bool = False, use_pca: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None

View File

@ -1,5 +1,29 @@
from stable_baselines3.common.off_policy_algorithm import *
import io
import pathlib
import sys
import time
import warnings
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit
from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
SelfOffPolicyAlgorithm = TypeVar("SelfOffPolicyAlgorithm", bound="BetterOffPolicyAlgorithm")
class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
"""
@ -52,6 +76,8 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
:param supported_action_spaces: The action spaces supported by the algorithm.
"""
actor: th.nn.Module
def __init__(
self,
policy: Union[str, Type[BasePolicy]],
@ -68,7 +94,7 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
policy_kwargs: Optional[Dict[str, Any]] = {},
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
verbose: int = 0,
@ -217,6 +243,9 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
if truncate_last_traj:
self.replay_buffer.truncate_last_trajectory()
# Update saved replay buffer device to match current setting, see GH#1561
self.replay_buffer.device = self.device
def _setup_learn(
self,
total_timesteps: int,
@ -248,10 +277,21 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
"You should use `reset_num_timesteps=False` or `optimize_memory_usage=False`"
"to avoid that issue."
)
assert replay_buffer is not None # for mypy
# Go to the previous index
pos = (replay_buffer.pos - 1) % replay_buffer.buffer_size
replay_buffer.dones[pos] = True
assert self.env is not None, "You must set the environment before calling _setup_learn()"
# Vectorize action noise if needed
if (
self.action_noise is not None
and self.env.num_envs > 1
and not isinstance(self.action_noise, VectorizedActionNoise)
):
self.action_noise = VectorizedActionNoise(self.action_noise, self.env.num_envs)
return super()._setup_learn(
total_timesteps,
callback,
@ -279,6 +319,9 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
callback.on_training_start(locals(), globals())
assert self.env is not None, "You must set the environment before calling learn()"
assert isinstance(self.train_freq, TrainFreq) # check done in _setup_learn()
while self.num_timesteps < total_timesteps:
rollout = self.collect_rollouts(
self.env,
@ -290,7 +333,7 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
log_interval=log_interval,
)
if rollout.continue_training is False:
if not rollout.continue_training:
break
if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
@ -341,6 +384,7 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
# Note: when using continuous actions,
# we assume that the policy uses tanh to scale the action
# We use non-deterministic action in the case of SAC, for TD3, it does not matter
assert self._last_obs is not None, "self._last_obs was not set"
unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
# Rescale the action from [low, high] to [-1, 1]
@ -364,6 +408,9 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
"""
Write log.
"""
assert self.ep_info_buffer is not None
assert self.ep_success_buffer is not None
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
@ -520,7 +567,7 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
# Give access to local variables
callback.update_locals(locals())
# Only stop training if return value is False, not when it is None.
if callback.on_step() is False:
if not callback.on_step():
return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training=False)
# Retrieve reward and episode length if using Monitor wrapper

View File

@ -1,5 +1,22 @@
from stable_baselines3.common.on_policy_algorithm import *
import sys
import time
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
SelfOnPolicyAlgorithm = TypeVar("SelfOnPolicyAlgorithm", bound="BetterOnPolicyAlgorithm")
class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
"""
@ -21,6 +38,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param use_pca: Whether to use PCA.
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation.
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
@ -50,6 +70,8 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
use_sde: bool,
sde_sample_freq: int,
use_pca: bool,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
monitor_wrapper: bool = True,
@ -62,6 +84,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
):
assert not (use_sde and use_pca)
self.use_pca = use_pca
assert not rollout_buffer_class and not rollout_buffer_kwargs
super().__init__(
policy=policy,
env=env,
@ -77,6 +102,8 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
device=device,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
#rollout_buffer_class = rollout_buffer_class,
#rollout_buffer_kwargs = rollout_buffer_kwargs,
# support_multi_env=True,
seed=seed,
stats_window_size=stats_window_size,
@ -86,13 +113,20 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
_init_setup_model=_init_setup_model
)
self.rollout_buffer_class = None
self.rollout_buffer_kwargs = {}
def _setup_model(self) -> None:
self._setup_lr_schedule()
self.set_random_seed(self.seed)
buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RolloutBuffer
if self.rollout_buffer_class is None:
if isinstance(self.observation_space, spaces.Dict):
self.rollout_buffer_class = DictRolloutBuffer
else:
self.rollout_buffer_class = RolloutBuffer
self.rollout_buffer = buffer_cls(
self.rollout_buffer = self.rollout_buffer_class(
self.n_steps,
self.observation_space,
self.action_space,
@ -100,6 +134,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
gamma=self.gamma,
gae_lambda=self.gae_lambda,
n_envs=self.n_envs,
**self.rollout_buffer_kwargs,
)
self.policy = self.policy_class( # pytype:disable=not-instantiable
self.observation_space,
@ -158,7 +193,14 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
clipped_actions = actions
# Clip the actions to avoid out of bound error
if isinstance(self.action_space, spaces.Box):
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
if self.policy.squash_output:
# Unscale the actions to match env bounds
# if they were previously squashed (scaled in [-1, 1])
clipped_actions = self.policy.unscale_action(clipped_actions)
else:
# Otherwise, clip the actions to avoid out of bound error
# as we are sampling from an unbounded Gaussian distribution
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
new_obs, rewards, dones, infos = env.step(clipped_actions)
@ -166,7 +208,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
# Give access to local variables
callback.update_locals(locals())
if callback.on_step() is False:
if not callback.on_step():
return False
self._update_info_buffer(infos)
@ -199,6 +241,8 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
callback.update_locals(locals())
callback.on_rollout_end()
return True
@ -234,7 +278,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
while self.num_timesteps < total_timesteps:
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
if continue_training is False:
if not continue_training:
break
iteration += 1
@ -242,6 +286,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
# Display training infos
if log_interval is not None and iteration % log_interval == 0:
assert self.ep_info_buffer is not None
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
self.logger.record("time/iterations", iteration, exclude="tensorboard")

View File

@ -0,0 +1,5 @@
try:
from priorConditionedAnnealing import PCA_Distribution
except ModuleNotFoundError:
def PCA_Distribution(*args, **kwargs):
raise Exception('PCA is not installed; cannot initialize PCA_Distribution.')

View File

@ -34,11 +34,12 @@ from stable_baselines3.common.torch_layers import (
from stable_baselines3.common.policies import ContinuousCritic
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
from .distributions import make_proba_distribution
from metastable_baselines2.pca import PCA_Distribution
from metastable_baselines2.common.pca import PCA_Distribution
SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")
@ -896,7 +897,7 @@ class Actor(BasePolicy):
else:
self.action_dist.base_noise.reset()
def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
"""
Get the parameters for the action distribution.
@ -917,17 +918,17 @@ class Actor(BasePolicy):
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
return mean_actions, log_std, {}
def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
# Note: the action is squashed
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]:
def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
# return action and associated log prob
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self(observation, deterministic)
@ -1106,10 +1107,10 @@ class SACPolicy(BasePolicy):
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
return ContinuousCritic(**critic_kwargs).to(self.device)
def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self.actor(observation, deterministic)
def set_training_mode(self, mode: bool) -> None:

View File

@ -113,7 +113,7 @@ class SAC(BetterOffPolicyAlgorithm):
use_pca: bool = False,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
policy_kwargs: Optional[Dict[str, Any]] = {},
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
@ -156,7 +156,7 @@ class SAC(BetterOffPolicyAlgorithm):
# Inverse of the reward scale
self.ent_coef = ent_coef
self.target_update_interval = target_update_interval
self.ent_coef_optimizer = Optional[th.optim.Adam] = None
self.ent_coef_optimizer = None
if _init_setup_model:
self._setup_model()