Implemented binding to PCA
This commit is contained in:
parent
497ee7e5fb
commit
9a4c43e233
45
sbBrix/common/distributions.py
Normal file
45
sbBrix/common/distributions.py
Normal file
@ -0,0 +1,45 @@
|
||||
from stable_baselines3.common.distributions import *
|
||||
from priorConditionedAnnealing import PCA_Distribution
|
||||
|
||||
|
||||
def _patched_make_proba_distribution(
|
||||
action_space: spaces.Space, use_sde: bool = False, use_pca: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
|
||||
) -> Distribution:
|
||||
"""
|
||||
Return an instance of Distribution for the correct type of action space
|
||||
|
||||
:param action_space: the input action space
|
||||
:param use_sde: Force the use of StateDependentNoiseDistribution
|
||||
instead of DiagGaussianDistribution
|
||||
:param dist_kwargs: Keyword arguments to pass to the probability distribution
|
||||
:return: the appropriate Distribution object
|
||||
"""
|
||||
|
||||
assert not (use_sde and use_pca), 'Can not mix sde and pca!'
|
||||
|
||||
if dist_kwargs is None:
|
||||
dist_kwargs = {}
|
||||
|
||||
if isinstance(action_space, spaces.Box):
|
||||
if use_sde:
|
||||
cls = StateDependentNoiseDistribution
|
||||
elif use_pca:
|
||||
cls = PCA_Distribution
|
||||
else:
|
||||
cls = DiagGaussianDistribution
|
||||
return cls(get_action_dim(action_space), **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.Discrete):
|
||||
return CategoricalDistribution(action_space.n, **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.MultiDiscrete):
|
||||
return MultiCategoricalDistribution(list(action_space.nvec), **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.MultiBinary):
|
||||
return BernoulliDistribution(action_space.n, **dist_kwargs)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Error: probability distribution, not implemented for action space"
|
||||
f"of type {type(action_space)}."
|
||||
" Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary."
|
||||
)
|
||||
|
||||
|
||||
_orig_make_propa_distribution, make_proba_distribution = make_proba_distribution, _patched_make_proba_distribution
|
554
sbBrix/common/off_policy_algorithm.py
Normal file
554
sbBrix/common/off_policy_algorithm.py
Normal file
@ -0,0 +1,554 @@
|
||||
from stable_baselines3.common.off_policy_algorithm import *
|
||||
|
||||
|
||||
class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
|
||||
"""
|
||||
The base for Off-Policy algorithms (ex: SAC/TD3)
|
||||
|
||||
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
|
||||
:param env: The environment to learn from
|
||||
(if registered in Gym, can be str. Can be None for loading trained models)
|
||||
:param learning_rate: learning rate for the optimizer,
|
||||
it can be a function of the current progress remaining (from 1 to 0)
|
||||
:param buffer_size: size of the replay buffer
|
||||
:param learning_starts: how many steps of the model to collect transitions for before learning starts
|
||||
:param batch_size: Minibatch size for each gradient update
|
||||
:param tau: the soft update coefficient ("Polyak update", between 0 and 1)
|
||||
:param gamma: the discount factor
|
||||
:param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
|
||||
like ``(5, "step")`` or ``(2, "episode")``.
|
||||
:param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
|
||||
Set to ``-1`` means to do as many gradient steps as steps done in the environment
|
||||
during the rollout.
|
||||
:param action_noise: the action noise type (None by default), this can help
|
||||
for hard exploration problem. Cf common.noise for the different action noise type.
|
||||
:param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
|
||||
If ``None``, it will be automatically selected.
|
||||
:param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
|
||||
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
|
||||
at a cost of more complexity.
|
||||
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
|
||||
:param policy_kwargs: Additional arguments to be passed to the policy on creation
|
||||
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
|
||||
the reported success rate, mean episode length, and mean reward over
|
||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
|
||||
debug messages
|
||||
:param device: Device on which the code should run.
|
||||
By default, it will try to use a Cuda compatible device and fallback to cpu
|
||||
if it is not possible.
|
||||
:param support_multi_env: Whether the algorithm supports training
|
||||
with multiple environments (as in A2C)
|
||||
:param monitor_wrapper: When creating an environment, whether to wrap it
|
||||
or not in a Monitor wrapper.
|
||||
:param seed: Seed for the pseudo random generators
|
||||
:param use_sde: Whether to use State Dependent Exploration (SDE)
|
||||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
|
||||
during the warm up phase (before learning starts)
|
||||
:param sde_support: Whether the model support gSDE or not
|
||||
:param supported_action_spaces: The action spaces supported by the algorithm.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: Union[str, Type[BasePolicy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, Schedule],
|
||||
buffer_size: int = 1_000_000, # 1e6
|
||||
learning_starts: int = 100,
|
||||
batch_size: int = 256,
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
train_freq: Union[int, Tuple[int, str]] = (1, "step"),
|
||||
gradient_steps: int = 1,
|
||||
action_noise: Optional[ActionNoise] = None,
|
||||
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
|
||||
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
optimize_memory_usage: bool = False,
|
||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||
stats_window_size: int = 100,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
verbose: int = 0,
|
||||
device: Union[th.device, str] = "auto",
|
||||
support_multi_env: bool = False,
|
||||
monitor_wrapper: bool = True,
|
||||
seed: Optional[int] = None,
|
||||
use_sde: bool = False,
|
||||
sde_sample_freq: int = -1,
|
||||
use_sde_at_warmup: bool = False,
|
||||
sde_support: bool = True,
|
||||
use_pca: bool = False,
|
||||
supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
|
||||
):
|
||||
assert not (use_sde and use_pca)
|
||||
self.use_pca = use_pca
|
||||
super().__init__(
|
||||
policy=policy,
|
||||
env=env,
|
||||
learning_rate=learning_rate,
|
||||
policy_kwargs=policy_kwargs,
|
||||
stats_window_size=stats_window_size,
|
||||
tensorboard_log=tensorboard_log,
|
||||
verbose=verbose,
|
||||
device=device,
|
||||
support_multi_env=support_multi_env,
|
||||
monitor_wrapper=monitor_wrapper,
|
||||
seed=seed,
|
||||
use_sde=use_sde,
|
||||
sde_sample_freq=sde_sample_freq,
|
||||
supported_action_spaces=supported_action_spaces,
|
||||
buffer_size=buffer_size,
|
||||
batch_size=batch_size,
|
||||
learning_starts=learning_starts,
|
||||
tau=tau,
|
||||
gamma=gamma,
|
||||
gradient_steps=gradient_steps,
|
||||
action_noise=action_noise,
|
||||
optimize_memory_usage=optimize_memory_usage,
|
||||
replay_buffer_class=replay_buffer_class,
|
||||
replay_buffer_kwargs=replay_buffer_kwargs,
|
||||
train_freq=train_freq,
|
||||
use_sde_at_warmup=use_sde_at_warmup,
|
||||
sde_support=sde_support,
|
||||
)
|
||||
|
||||
def _convert_train_freq(self) -> None:
|
||||
"""
|
||||
Convert `train_freq` parameter (int or tuple)
|
||||
to a TrainFreq object.
|
||||
"""
|
||||
if not isinstance(self.train_freq, TrainFreq):
|
||||
train_freq = self.train_freq
|
||||
|
||||
# The value of the train frequency will be checked later
|
||||
if not isinstance(train_freq, tuple):
|
||||
train_freq = (train_freq, "step")
|
||||
|
||||
try:
|
||||
train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!"
|
||||
) from e
|
||||
|
||||
if not isinstance(train_freq[0], int):
|
||||
raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}")
|
||||
|
||||
self.train_freq = TrainFreq(*train_freq)
|
||||
|
||||
def _setup_model(self) -> None:
|
||||
self._setup_lr_schedule()
|
||||
self.set_random_seed(self.seed)
|
||||
|
||||
if self.replay_buffer_class is None:
|
||||
if isinstance(self.observation_space, spaces.Dict):
|
||||
self.replay_buffer_class = DictReplayBuffer
|
||||
else:
|
||||
self.replay_buffer_class = ReplayBuffer
|
||||
|
||||
if self.replay_buffer is None:
|
||||
# Make a local copy as we should not pickle
|
||||
# the environment when using HerReplayBuffer
|
||||
replay_buffer_kwargs = self.replay_buffer_kwargs.copy()
|
||||
if issubclass(self.replay_buffer_class, HerReplayBuffer):
|
||||
assert self.env is not None, "You must pass an environment when using `HerReplayBuffer`"
|
||||
replay_buffer_kwargs["env"] = self.env
|
||||
self.replay_buffer = self.replay_buffer_class(
|
||||
self.buffer_size,
|
||||
self.observation_space,
|
||||
self.action_space,
|
||||
device=self.device,
|
||||
n_envs=self.n_envs,
|
||||
optimize_memory_usage=self.optimize_memory_usage,
|
||||
**replay_buffer_kwargs, # pytype:disable=wrong-keyword-args
|
||||
)
|
||||
|
||||
self.policy = self.policy_class( # pytype:disable=not-instantiable
|
||||
self.observation_space,
|
||||
self.action_space,
|
||||
self.lr_schedule,
|
||||
**self.policy_kwargs, # pytype:disable=not-instantiable
|
||||
)
|
||||
self.policy = self.policy.to(self.device)
|
||||
|
||||
# Convert train freq parameter to TrainFreq object
|
||||
self._convert_train_freq()
|
||||
|
||||
def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
|
||||
"""
|
||||
Save the replay buffer as a pickle file.
|
||||
|
||||
:param path: Path to the file where the replay buffer should be saved.
|
||||
if path is a str or pathlib.Path, the path is automatically created if necessary.
|
||||
"""
|
||||
assert self.replay_buffer is not None, "The replay buffer is not defined"
|
||||
save_to_pkl(path, self.replay_buffer, self.verbose)
|
||||
|
||||
def load_replay_buffer(
|
||||
self,
|
||||
path: Union[str, pathlib.Path, io.BufferedIOBase],
|
||||
truncate_last_traj: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Load a replay buffer from a pickle file.
|
||||
|
||||
:param path: Path to the pickled replay buffer.
|
||||
:param truncate_last_traj: When using ``HerReplayBuffer`` with online sampling:
|
||||
If set to ``True``, we assume that the last trajectory in the replay buffer was finished
|
||||
(and truncate it).
|
||||
If set to ``False``, we assume that we continue the same trajectory (same episode).
|
||||
"""
|
||||
self.replay_buffer = load_from_pkl(path, self.verbose)
|
||||
assert isinstance(self.replay_buffer, ReplayBuffer), "The replay buffer must inherit from ReplayBuffer class"
|
||||
|
||||
# Backward compatibility with SB3 < 2.1.0 replay buffer
|
||||
# Keep old behavior: do not handle timeout termination separately
|
||||
if not hasattr(self.replay_buffer, "handle_timeout_termination"): # pragma: no cover
|
||||
self.replay_buffer.handle_timeout_termination = False
|
||||
self.replay_buffer.timeouts = np.zeros_like(self.replay_buffer.dones)
|
||||
|
||||
if isinstance(self.replay_buffer, HerReplayBuffer):
|
||||
assert self.env is not None, "You must pass an environment at load time when using `HerReplayBuffer`"
|
||||
self.replay_buffer.set_env(self.get_env())
|
||||
if truncate_last_traj:
|
||||
self.replay_buffer.truncate_last_trajectory()
|
||||
|
||||
def _setup_learn(
|
||||
self,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
tb_log_name: str = "run",
|
||||
progress_bar: bool = False,
|
||||
) -> Tuple[int, BaseCallback]:
|
||||
"""
|
||||
cf `BaseAlgorithm`.
|
||||
"""
|
||||
# Prevent continuity issue by truncating trajectory
|
||||
# when using memory efficient replay buffer
|
||||
# see https://github.com/DLR-RM/stable-baselines3/issues/46
|
||||
|
||||
replay_buffer = self.replay_buffer
|
||||
|
||||
truncate_last_traj = (
|
||||
self.optimize_memory_usage
|
||||
and reset_num_timesteps
|
||||
and replay_buffer is not None
|
||||
and (replay_buffer.full or replay_buffer.pos > 0)
|
||||
)
|
||||
|
||||
if truncate_last_traj:
|
||||
warnings.warn(
|
||||
"The last trajectory in the replay buffer will be truncated, "
|
||||
"see https://github.com/DLR-RM/stable-baselines3/issues/46."
|
||||
"You should use `reset_num_timesteps=False` or `optimize_memory_usage=False`"
|
||||
"to avoid that issue."
|
||||
)
|
||||
# Go to the previous index
|
||||
pos = (replay_buffer.pos - 1) % replay_buffer.buffer_size
|
||||
replay_buffer.dones[pos] = True
|
||||
|
||||
return super()._setup_learn(
|
||||
total_timesteps,
|
||||
callback,
|
||||
reset_num_timesteps,
|
||||
tb_log_name,
|
||||
progress_bar,
|
||||
)
|
||||
|
||||
def learn(
|
||||
self: SelfOffPolicyAlgorithm,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
tb_log_name: str = "run",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> SelfOffPolicyAlgorithm:
|
||||
total_timesteps, callback = self._setup_learn(
|
||||
total_timesteps,
|
||||
callback,
|
||||
reset_num_timesteps,
|
||||
tb_log_name,
|
||||
progress_bar,
|
||||
)
|
||||
|
||||
callback.on_training_start(locals(), globals())
|
||||
|
||||
while self.num_timesteps < total_timesteps:
|
||||
rollout = self.collect_rollouts(
|
||||
self.env,
|
||||
train_freq=self.train_freq,
|
||||
action_noise=self.action_noise,
|
||||
callback=callback,
|
||||
learning_starts=self.learning_starts,
|
||||
replay_buffer=self.replay_buffer,
|
||||
log_interval=log_interval,
|
||||
)
|
||||
|
||||
if rollout.continue_training is False:
|
||||
break
|
||||
|
||||
if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
|
||||
# If no `gradient_steps` is specified,
|
||||
# do as many gradients steps as steps performed during the rollout
|
||||
gradient_steps = self.gradient_steps if self.gradient_steps >= 0 else rollout.episode_timesteps
|
||||
# Special case when the user passes `gradient_steps=0`
|
||||
if gradient_steps > 0:
|
||||
self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
|
||||
|
||||
callback.on_training_end()
|
||||
|
||||
return self
|
||||
|
||||
def train(self, gradient_steps: int, batch_size: int) -> None:
|
||||
"""
|
||||
Sample the replay buffer and do the updates
|
||||
(gradient descent and update target networks)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _sample_action(
|
||||
self,
|
||||
learning_starts: int,
|
||||
action_noise: Optional[ActionNoise] = None,
|
||||
n_envs: int = 1,
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Sample an action according to the exploration policy.
|
||||
This is either done by sampling the probability distribution of the policy,
|
||||
or sampling a random action (from a uniform distribution over the action space)
|
||||
or by adding noise to the deterministic output.
|
||||
|
||||
:param action_noise: Action noise that will be used for exploration
|
||||
Required for deterministic policy (e.g. TD3). This can also be used
|
||||
in addition to the stochastic policy for SAC.
|
||||
:param learning_starts: Number of steps before learning for the warm-up phase.
|
||||
:param n_envs:
|
||||
:return: action to take in the environment
|
||||
and scaled action that will be stored in the replay buffer.
|
||||
The two differs when the action space is not normalized (bounds are not [-1, 1]).
|
||||
"""
|
||||
# Select action randomly or according to policy
|
||||
if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
|
||||
# Warmup phase
|
||||
unscaled_action = np.array([self.action_space.sample() for _ in range(n_envs)])
|
||||
else:
|
||||
# Note: when using continuous actions,
|
||||
# we assume that the policy uses tanh to scale the action
|
||||
# We use non-deterministic action in the case of SAC, for TD3, it does not matter
|
||||
unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
|
||||
|
||||
# Rescale the action from [low, high] to [-1, 1]
|
||||
if isinstance(self.action_space, spaces.Box):
|
||||
scaled_action = self.policy.scale_action(unscaled_action)
|
||||
|
||||
# Add noise to the action (improve exploration)
|
||||
if action_noise is not None:
|
||||
scaled_action = np.clip(scaled_action + action_noise(), -1, 1)
|
||||
|
||||
# We store the scaled action in the buffer
|
||||
buffer_action = scaled_action
|
||||
action = self.policy.unscale_action(scaled_action)
|
||||
else:
|
||||
# Discrete case, no need to normalize or clip
|
||||
buffer_action = unscaled_action
|
||||
action = buffer_action
|
||||
return action, buffer_action
|
||||
|
||||
def _dump_logs(self) -> None:
|
||||
"""
|
||||
Write log.
|
||||
"""
|
||||
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
||||
self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
||||
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
||||
self.logger.record("time/fps", fps)
|
||||
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
|
||||
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
||||
if self.use_sde or self.use_pca:
|
||||
self.logger.record("train/std", (self.actor.get_std()).mean().item())
|
||||
|
||||
if len(self.ep_success_buffer) > 0:
|
||||
self.logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer))
|
||||
# Pass the number of timesteps for tensorboard
|
||||
self.logger.dump(step=self.num_timesteps)
|
||||
|
||||
def _on_step(self) -> None:
|
||||
"""
|
||||
Method called after each step in the environment.
|
||||
It is meant to trigger DQN target network update
|
||||
but can be used for other purposes
|
||||
"""
|
||||
pass
|
||||
|
||||
def _store_transition(
|
||||
self,
|
||||
replay_buffer: ReplayBuffer,
|
||||
buffer_action: np.ndarray,
|
||||
new_obs: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||
reward: np.ndarray,
|
||||
dones: np.ndarray,
|
||||
infos: List[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""
|
||||
Store transition in the replay buffer.
|
||||
We store the normalized action and the unnormalized observation.
|
||||
It also handles terminal observations (because VecEnv resets automatically).
|
||||
|
||||
:param replay_buffer: Replay buffer object where to store the transition.
|
||||
:param buffer_action: normalized action
|
||||
:param new_obs: next observation in the current episode
|
||||
or first observation of the episode (when dones is True)
|
||||
:param reward: reward for the current transition
|
||||
:param dones: Termination signal
|
||||
:param infos: List of additional information about the transition.
|
||||
It may contain the terminal observations and information about timeout.
|
||||
"""
|
||||
# Store only the unnormalized version
|
||||
if self._vec_normalize_env is not None:
|
||||
new_obs_ = self._vec_normalize_env.get_original_obs()
|
||||
reward_ = self._vec_normalize_env.get_original_reward()
|
||||
else:
|
||||
# Avoid changing the original ones
|
||||
self._last_original_obs, new_obs_, reward_ = self._last_obs, new_obs, reward
|
||||
|
||||
# Avoid modification by reference
|
||||
next_obs = deepcopy(new_obs_)
|
||||
# As the VecEnv resets automatically, new_obs is already the
|
||||
# first observation of the next episode
|
||||
for i, done in enumerate(dones):
|
||||
if done and infos[i].get("terminal_observation") is not None:
|
||||
if isinstance(next_obs, dict):
|
||||
next_obs_ = infos[i]["terminal_observation"]
|
||||
# VecNormalize normalizes the terminal observation
|
||||
if self._vec_normalize_env is not None:
|
||||
next_obs_ = self._vec_normalize_env.unnormalize_obs(next_obs_)
|
||||
# Replace next obs for the correct envs
|
||||
for key in next_obs.keys():
|
||||
next_obs[key][i] = next_obs_[key]
|
||||
else:
|
||||
next_obs[i] = infos[i]["terminal_observation"]
|
||||
# VecNormalize normalizes the terminal observation
|
||||
if self._vec_normalize_env is not None:
|
||||
next_obs[i] = self._vec_normalize_env.unnormalize_obs(next_obs[i, :])
|
||||
|
||||
replay_buffer.add(
|
||||
self._last_original_obs,
|
||||
next_obs,
|
||||
buffer_action,
|
||||
reward_,
|
||||
dones,
|
||||
infos,
|
||||
)
|
||||
|
||||
self._last_obs = new_obs
|
||||
# Save the unnormalized observation
|
||||
if self._vec_normalize_env is not None:
|
||||
self._last_original_obs = new_obs_
|
||||
|
||||
def collect_rollouts(
|
||||
self,
|
||||
env: VecEnv,
|
||||
callback: BaseCallback,
|
||||
train_freq: TrainFreq,
|
||||
replay_buffer: ReplayBuffer,
|
||||
action_noise: Optional[ActionNoise] = None,
|
||||
learning_starts: int = 0,
|
||||
log_interval: Optional[int] = None,
|
||||
) -> RolloutReturn:
|
||||
"""
|
||||
Collect experiences and store them into a ``ReplayBuffer``.
|
||||
|
||||
:param env: The training environment
|
||||
:param callback: Callback that will be called at each step
|
||||
(and at the beginning and end of the rollout)
|
||||
:param train_freq: How much experience to collect
|
||||
by doing rollouts of current policy.
|
||||
Either ``TrainFreq(<n>, TrainFrequencyUnit.STEP)``
|
||||
or ``TrainFreq(<n>, TrainFrequencyUnit.EPISODE)``
|
||||
with ``<n>`` being an integer greater than 0.
|
||||
:param action_noise: Action noise that will be used for exploration
|
||||
Required for deterministic policy (e.g. TD3). This can also be used
|
||||
in addition to the stochastic policy for SAC.
|
||||
:param learning_starts: Number of steps before learning for the warm-up phase.
|
||||
:param replay_buffer:
|
||||
:param log_interval: Log data every ``log_interval`` episodes
|
||||
:return:
|
||||
"""
|
||||
# Switch to eval mode (this affects batch norm / dropout)
|
||||
self.policy.set_training_mode(False)
|
||||
|
||||
num_collected_steps, num_collected_episodes = 0, 0
|
||||
|
||||
assert isinstance(env, VecEnv), "You must pass a VecEnv"
|
||||
assert train_freq.frequency > 0, "Should at least collect one step or episode."
|
||||
|
||||
if env.num_envs > 1:
|
||||
assert train_freq.unit == TrainFrequencyUnit.STEP, "You must use only one env when doing episodic training."
|
||||
|
||||
# Vectorize action noise if needed
|
||||
if action_noise is not None and env.num_envs > 1 and not isinstance(action_noise, VectorizedActionNoise):
|
||||
action_noise = VectorizedActionNoise(action_noise, env.num_envs)
|
||||
|
||||
if self.use_sde:
|
||||
self.actor.reset_noise(env.num_envs)
|
||||
|
||||
callback.on_rollout_start()
|
||||
continue_training = True
|
||||
while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
|
||||
if self.use_sde and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
|
||||
# Sample a new noise matrix
|
||||
self.actor.reset_noise(env.num_envs)
|
||||
|
||||
# Select action randomly or according to policy
|
||||
actions, buffer_actions = self._sample_action(learning_starts, action_noise, env.num_envs)
|
||||
|
||||
# Rescale and perform action
|
||||
new_obs, rewards, dones, infos = env.step(actions)
|
||||
|
||||
self.num_timesteps += env.num_envs
|
||||
num_collected_steps += 1
|
||||
|
||||
# Give access to local variables
|
||||
callback.update_locals(locals())
|
||||
# Only stop training if return value is False, not when it is None.
|
||||
if callback.on_step() is False:
|
||||
return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training=False)
|
||||
|
||||
# Retrieve reward and episode length if using Monitor wrapper
|
||||
self._update_info_buffer(infos, dones)
|
||||
|
||||
# Store data in replay buffer (normalized action and unnormalized observation)
|
||||
self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos)
|
||||
|
||||
self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)
|
||||
|
||||
# For DQN, check if the target network should be updated
|
||||
# and update the exploration schedule
|
||||
# For SAC/TD3, the update is dones as the same time as the gradient update
|
||||
# see https://github.com/hill-a/stable-baselines/issues/900
|
||||
self._on_step()
|
||||
|
||||
for idx, done in enumerate(dones):
|
||||
if done:
|
||||
# Update stats
|
||||
num_collected_episodes += 1
|
||||
self._episode_num += 1
|
||||
|
||||
if action_noise is not None:
|
||||
kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
|
||||
action_noise.reset(**kwargs)
|
||||
|
||||
# Log training infos
|
||||
if log_interval is not None and self._episode_num % log_interval == 0:
|
||||
self._dump_logs()
|
||||
callback.on_rollout_end()
|
||||
|
||||
return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)
|
265
sbBrix/common/on_policy_algorithm.py
Normal file
265
sbBrix/common/on_policy_algorithm.py
Normal file
@ -0,0 +1,265 @@
|
||||
from stable_baselines3.common.on_policy_algorithm import *
|
||||
|
||||
|
||||
class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
||||
"""
|
||||
The base for On-Policy algorithms (ex: A2C/PPO).
|
||||
|
||||
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
|
||||
:param env: The environment to learn from (if registered in Gym, can be str)
|
||||
:param learning_rate: The learning rate, it can be a function
|
||||
of the current progress remaining (from 1 to 0)
|
||||
:param n_steps: The number of steps to run for each environment per update
|
||||
(i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
|
||||
:param gamma: Discount factor
|
||||
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator.
|
||||
Equivalent to classic advantage when set to 1.
|
||||
:param ent_coef: Entropy coefficient for the loss calculation
|
||||
:param vf_coef: Value function coefficient for the loss calculation
|
||||
:param max_grad_norm: The maximum value for the gradient clipping
|
||||
:param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
|
||||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
|
||||
the reported success rate, mean episode length, and mean reward over
|
||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||
:param monitor_wrapper: When creating an environment, whether to wrap it
|
||||
or not in a Monitor wrapper.
|
||||
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
||||
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
|
||||
debug messages
|
||||
:param seed: Seed for the pseudo random generators
|
||||
:param device: Device (cpu, cuda, ...) on which the code should be run.
|
||||
Setting it to auto, the code will be run on the GPU if possible.
|
||||
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
||||
:param supported_action_spaces: The action spaces supported by the algorithm.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: Union[str, Type[ActorCriticPolicy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, Schedule],
|
||||
n_steps: int,
|
||||
gamma: float,
|
||||
gae_lambda: float,
|
||||
ent_coef: float,
|
||||
vf_coef: float,
|
||||
max_grad_norm: float,
|
||||
use_sde: bool,
|
||||
sde_sample_freq: int,
|
||||
use_pca: bool,
|
||||
stats_window_size: int = 100,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
monitor_wrapper: bool = True,
|
||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||
verbose: int = 0,
|
||||
seed: Optional[int] = None,
|
||||
device: Union[th.device, str] = "auto",
|
||||
_init_setup_model: bool = True,
|
||||
supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
|
||||
):
|
||||
assert not (use_sde and use_pca)
|
||||
self.use_pca = use_pca
|
||||
super().__init__(
|
||||
policy=policy,
|
||||
env=env,
|
||||
learning_rate=learning_rate,
|
||||
n_steps=n_steps,
|
||||
gamma=gamma,
|
||||
gae_lambda=gae_lambda,
|
||||
ent_coef=ent_coef,
|
||||
vf_coef=vf_coef,
|
||||
max_grad_norm=max_grad_norm,
|
||||
policy_kwargs=policy_kwargs,
|
||||
verbose=verbose,
|
||||
device=device,
|
||||
use_sde=use_sde,
|
||||
sde_sample_freq=sde_sample_freq,
|
||||
# support_multi_env=True,
|
||||
seed=seed,
|
||||
stats_window_size=stats_window_size,
|
||||
tensorboard_log=tensorboard_log,
|
||||
supported_action_spaces=supported_action_spaces,
|
||||
monitor_wrapper=monitor_wrapper,
|
||||
_init_setup_model=_init_setup_model
|
||||
)
|
||||
|
||||
def _setup_model(self) -> None:
|
||||
self._setup_lr_schedule()
|
||||
self.set_random_seed(self.seed)
|
||||
|
||||
buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RolloutBuffer
|
||||
|
||||
self.rollout_buffer = buffer_cls(
|
||||
self.n_steps,
|
||||
self.observation_space,
|
||||
self.action_space,
|
||||
device=self.device,
|
||||
gamma=self.gamma,
|
||||
gae_lambda=self.gae_lambda,
|
||||
n_envs=self.n_envs,
|
||||
)
|
||||
self.policy = self.policy_class( # pytype:disable=not-instantiable
|
||||
self.observation_space,
|
||||
self.action_space,
|
||||
self.lr_schedule,
|
||||
use_sde=self.use_sde,
|
||||
use_pca=self.use_pca,
|
||||
**self.policy_kwargs # pytype:disable=not-instantiable
|
||||
)
|
||||
self.policy = self.policy.to(self.device)
|
||||
|
||||
def collect_rollouts(
|
||||
self,
|
||||
env: VecEnv,
|
||||
callback: BaseCallback,
|
||||
rollout_buffer: RolloutBuffer,
|
||||
n_rollout_steps: int,
|
||||
) -> bool:
|
||||
"""
|
||||
Collect experiences using the current policy and fill a ``RolloutBuffer``.
|
||||
The term rollout here refers to the model-free notion and should not
|
||||
be used with the concept of rollout used in model-based RL or planning.
|
||||
|
||||
:param env: The training environment
|
||||
:param callback: Callback that will be called at each step
|
||||
(and at the beginning and end of the rollout)
|
||||
:param rollout_buffer: Buffer to fill with rollouts
|
||||
:param n_rollout_steps: Number of experiences to collect per environment
|
||||
:return: True if function returned with at least `n_rollout_steps`
|
||||
collected, False if callback terminated rollout prematurely.
|
||||
"""
|
||||
assert self._last_obs is not None, "No previous observation was provided"
|
||||
# Switch to eval mode (this affects batch norm / dropout)
|
||||
self.policy.set_training_mode(False)
|
||||
|
||||
n_steps = 0
|
||||
rollout_buffer.reset()
|
||||
# Sample new weights for the state dependent exploration
|
||||
if self.use_sde:
|
||||
self.policy.reset_noise(env.num_envs)
|
||||
|
||||
callback.on_rollout_start()
|
||||
|
||||
while n_steps < n_rollout_steps:
|
||||
if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
|
||||
# Sample a new noise matrix
|
||||
self.policy.reset_noise(env.num_envs)
|
||||
|
||||
with th.no_grad():
|
||||
# Convert to pytorch tensor or to TensorDict
|
||||
obs_tensor = obs_as_tensor(self._last_obs, self.device)
|
||||
actions, values, log_probs = self.policy(obs_tensor)
|
||||
actions = actions.cpu().numpy()
|
||||
|
||||
# Rescale and perform action
|
||||
clipped_actions = actions
|
||||
# Clip the actions to avoid out of bound error
|
||||
if isinstance(self.action_space, spaces.Box):
|
||||
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
||||
|
||||
new_obs, rewards, dones, infos = env.step(clipped_actions)
|
||||
|
||||
self.num_timesteps += env.num_envs
|
||||
|
||||
# Give access to local variables
|
||||
callback.update_locals(locals())
|
||||
if callback.on_step() is False:
|
||||
return False
|
||||
|
||||
self._update_info_buffer(infos)
|
||||
n_steps += 1
|
||||
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
# Reshape in case of discrete action
|
||||
actions = actions.reshape(-1, 1)
|
||||
|
||||
# Handle timeout by bootstraping with value function
|
||||
# see GitHub issue #633
|
||||
for idx, done in enumerate(dones):
|
||||
if (
|
||||
done
|
||||
and infos[idx].get("terminal_observation") is not None
|
||||
and infos[idx].get("TimeLimit.truncated", False)
|
||||
):
|
||||
terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
|
||||
with th.no_grad():
|
||||
terminal_value = self.policy.predict_values(terminal_obs)[0]
|
||||
rewards[idx] += self.gamma * terminal_value
|
||||
|
||||
rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)
|
||||
self._last_obs = new_obs
|
||||
self._last_episode_starts = dones
|
||||
|
||||
with th.no_grad():
|
||||
# Compute value for the last timestep
|
||||
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))
|
||||
|
||||
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
|
||||
|
||||
callback.on_rollout_end()
|
||||
|
||||
return True
|
||||
|
||||
def train(self) -> None:
|
||||
"""
|
||||
Consume current rollout data and update policy parameters.
|
||||
Implemented by individual algorithms.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def learn(
|
||||
self: SelfOnPolicyAlgorithm,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
tb_log_name: str = "OnPolicyAlgorithm",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> SelfOnPolicyAlgorithm:
|
||||
iteration = 0
|
||||
|
||||
total_timesteps, callback = self._setup_learn(
|
||||
total_timesteps,
|
||||
callback,
|
||||
reset_num_timesteps,
|
||||
tb_log_name,
|
||||
progress_bar,
|
||||
)
|
||||
|
||||
callback.on_training_start(locals(), globals())
|
||||
|
||||
while self.num_timesteps < total_timesteps:
|
||||
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
|
||||
|
||||
if continue_training is False:
|
||||
break
|
||||
|
||||
iteration += 1
|
||||
self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
|
||||
|
||||
# Display training infos
|
||||
if log_interval is not None and iteration % log_interval == 0:
|
||||
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
||||
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
||||
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
||||
self.logger.record("time/fps", fps)
|
||||
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
|
||||
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
||||
self.logger.dump(step=self.num_timesteps)
|
||||
|
||||
self.train()
|
||||
|
||||
callback.on_training_end()
|
||||
|
||||
return self
|
||||
|
||||
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
|
||||
state_dicts = ["policy", "policy.optimizer"]
|
||||
|
||||
return state_dicts, []
|
1084
sbBrix/common/policies.py
Normal file
1084
sbBrix/common/policies.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -6,15 +6,17 @@ import torch as th
|
||||
from gym import spaces
|
||||
from torch.nn import functional as F
|
||||
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
|
||||
# from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from ..common.on_policy_algorithm import BetterOnPolicyAlgorithm
|
||||
# from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
|
||||
from ..common.policies import ActorCriticPolicy, BasePolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
|
||||
|
||||
SelfPPO = TypeVar("SelfPPO", bound="PPO")
|
||||
|
||||
|
||||
class PPO(OnPolicyAlgorithm):
|
||||
class PPO(BetterOnPolicyAlgorithm):
|
||||
"""
|
||||
Proximal Policy Optimization algorithm (PPO) (clip version)
|
||||
|
||||
@ -69,9 +71,7 @@ class PPO(OnPolicyAlgorithm):
|
||||
"""
|
||||
|
||||
policy_aliases: Dict[str, Type[BasePolicy]] = {
|
||||
"MlpPolicy": ActorCriticPolicy,
|
||||
"CnnPolicy": ActorCriticCnnPolicy,
|
||||
"MultiInputPolicy": MultiInputActorCriticPolicy,
|
||||
"MlpPolicy": ActorCriticPolicy
|
||||
}
|
||||
|
||||
def __init__(
|
||||
@ -92,6 +92,7 @@ class PPO(OnPolicyAlgorithm):
|
||||
max_grad_norm: float = 0.5,
|
||||
use_sde: bool = False,
|
||||
sde_sample_freq: int = -1,
|
||||
use_pca: bool = False,
|
||||
target_kl: Optional[float] = None,
|
||||
stats_window_size: int = 100,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
@ -113,6 +114,7 @@ class PPO(OnPolicyAlgorithm):
|
||||
max_grad_norm=max_grad_norm,
|
||||
use_sde=use_sde,
|
||||
sde_sample_freq=sde_sample_freq,
|
||||
use_pca=use_pca,
|
||||
stats_window_size=stats_window_size,
|
||||
tensorboard_log=tensorboard_log,
|
||||
policy_kwargs=policy_kwargs,
|
||||
@ -315,4 +317,3 @@ class PPO(OnPolicyAlgorithm):
|
||||
reset_num_timesteps=reset_num_timesteps,
|
||||
progress_bar=progress_bar,
|
||||
)
|
||||
|
||||
|
@ -7,7 +7,8 @@ from torch.nn import functional as F
|
||||
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
# from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from common.off_policy_algorithm import BetterOffPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
||||
@ -16,7 +17,7 @@ from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolic
|
||||
SelfSAC = TypeVar("SelfSAC", bound="SAC")
|
||||
|
||||
|
||||
class SAC(OffPolicyAlgorithm):
|
||||
class SAC(BetterOffPolicyAlgorithm):
|
||||
"""
|
||||
Soft Actor-Critic (SAC)
|
||||
Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor,
|
||||
@ -321,4 +322,3 @@ class SAC(OffPolicyAlgorithm):
|
||||
else:
|
||||
saved_pytorch_variables = ["ent_coef_tensor"]
|
||||
return state_dicts, saved_pytorch_variables
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user