Updating to new sb3 version
This commit is contained in:
parent
1c1a909d27
commit
865efe4221
@ -1,12 +1,5 @@
|
||||
from sbBrix.ppo import PPO
|
||||
from sbBrix.sac import SAC
|
||||
|
||||
try:
|
||||
import priorConditionedAnnealing as pca
|
||||
except ModuleNotFoundError:
|
||||
class pca():
|
||||
def PCA_Distribution(*args, **kwargs):
|
||||
raise Exception('PCA is not installed; cannot initialize PCA_Distribution.')
|
||||
from metastable_baselines2.ppo import PPO
|
||||
from metastable_baselines2.sac import SAC
|
||||
|
||||
__all__ = [
|
||||
"PPO",
|
||||
|
@ -1,6 +1,5 @@
|
||||
from stable_baselines3.common.distributions import *
|
||||
from metastable_baselines2.pca import PCA_Distribution
|
||||
|
||||
from metastable_baselines2.common.pca import PCA_Distribution
|
||||
|
||||
def _patched_make_proba_distribution(
|
||||
action_space: spaces.Space, use_sde: bool = False, use_pca: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
|
||||
|
@ -1,5 +1,29 @@
|
||||
from stable_baselines3.common.off_policy_algorithm import *
|
||||
import io
|
||||
import pathlib
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gymnasium import spaces
|
||||
|
||||
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||
from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit
|
||||
from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
|
||||
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
|
||||
SelfOffPolicyAlgorithm = TypeVar("SelfOffPolicyAlgorithm", bound="BetterOffPolicyAlgorithm")
|
||||
|
||||
class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
|
||||
"""
|
||||
@ -52,6 +76,8 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
|
||||
:param supported_action_spaces: The action spaces supported by the algorithm.
|
||||
"""
|
||||
|
||||
actor: th.nn.Module
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: Union[str, Type[BasePolicy]],
|
||||
@ -68,7 +94,7 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
|
||||
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
|
||||
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
optimize_memory_usage: bool = False,
|
||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||
policy_kwargs: Optional[Dict[str, Any]] = {},
|
||||
stats_window_size: int = 100,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
verbose: int = 0,
|
||||
@ -217,6 +243,9 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
|
||||
if truncate_last_traj:
|
||||
self.replay_buffer.truncate_last_trajectory()
|
||||
|
||||
# Update saved replay buffer device to match current setting, see GH#1561
|
||||
self.replay_buffer.device = self.device
|
||||
|
||||
def _setup_learn(
|
||||
self,
|
||||
total_timesteps: int,
|
||||
@ -248,10 +277,21 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
|
||||
"You should use `reset_num_timesteps=False` or `optimize_memory_usage=False`"
|
||||
"to avoid that issue."
|
||||
)
|
||||
assert replay_buffer is not None # for mypy
|
||||
# Go to the previous index
|
||||
pos = (replay_buffer.pos - 1) % replay_buffer.buffer_size
|
||||
replay_buffer.dones[pos] = True
|
||||
|
||||
assert self.env is not None, "You must set the environment before calling _setup_learn()"
|
||||
|
||||
# Vectorize action noise if needed
|
||||
if (
|
||||
self.action_noise is not None
|
||||
and self.env.num_envs > 1
|
||||
and not isinstance(self.action_noise, VectorizedActionNoise)
|
||||
):
|
||||
self.action_noise = VectorizedActionNoise(self.action_noise, self.env.num_envs)
|
||||
|
||||
return super()._setup_learn(
|
||||
total_timesteps,
|
||||
callback,
|
||||
@ -279,6 +319,9 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
|
||||
|
||||
callback.on_training_start(locals(), globals())
|
||||
|
||||
assert self.env is not None, "You must set the environment before calling learn()"
|
||||
assert isinstance(self.train_freq, TrainFreq) # check done in _setup_learn()
|
||||
|
||||
while self.num_timesteps < total_timesteps:
|
||||
rollout = self.collect_rollouts(
|
||||
self.env,
|
||||
@ -290,7 +333,7 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
|
||||
log_interval=log_interval,
|
||||
)
|
||||
|
||||
if rollout.continue_training is False:
|
||||
if not rollout.continue_training:
|
||||
break
|
||||
|
||||
if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
|
||||
@ -341,6 +384,7 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
|
||||
# Note: when using continuous actions,
|
||||
# we assume that the policy uses tanh to scale the action
|
||||
# We use non-deterministic action in the case of SAC, for TD3, it does not matter
|
||||
assert self._last_obs is not None, "self._last_obs was not set"
|
||||
unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
|
||||
|
||||
# Rescale the action from [low, high] to [-1, 1]
|
||||
@ -364,6 +408,9 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
|
||||
"""
|
||||
Write log.
|
||||
"""
|
||||
assert self.ep_info_buffer is not None
|
||||
assert self.ep_success_buffer is not None
|
||||
|
||||
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
||||
self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
|
||||
@ -520,7 +567,7 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
|
||||
# Give access to local variables
|
||||
callback.update_locals(locals())
|
||||
# Only stop training if return value is False, not when it is None.
|
||||
if callback.on_step() is False:
|
||||
if not callback.on_step():
|
||||
return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training=False)
|
||||
|
||||
# Retrieve reward and episode length if using Monitor wrapper
|
||||
|
@ -1,5 +1,22 @@
|
||||
from stable_baselines3.common.on_policy_algorithm import *
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gymnasium import spaces
|
||||
|
||||
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
|
||||
SelfOnPolicyAlgorithm = TypeVar("SelfOnPolicyAlgorithm", bound="BetterOnPolicyAlgorithm")
|
||||
|
||||
class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
||||
"""
|
||||
@ -21,6 +38,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
||||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param use_pca: Whether to use PCA.
|
||||
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
|
||||
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation.
|
||||
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
|
||||
the reported success rate, mean episode length, and mean reward over
|
||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||
@ -50,6 +70,8 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
||||
use_sde: bool,
|
||||
sde_sample_freq: int,
|
||||
use_pca: bool,
|
||||
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
|
||||
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
stats_window_size: int = 100,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
monitor_wrapper: bool = True,
|
||||
@ -62,6 +84,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
||||
):
|
||||
assert not (use_sde and use_pca)
|
||||
self.use_pca = use_pca
|
||||
|
||||
assert not rollout_buffer_class and not rollout_buffer_kwargs
|
||||
|
||||
super().__init__(
|
||||
policy=policy,
|
||||
env=env,
|
||||
@ -77,6 +102,8 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
||||
device=device,
|
||||
use_sde=use_sde,
|
||||
sde_sample_freq=sde_sample_freq,
|
||||
#rollout_buffer_class = rollout_buffer_class,
|
||||
#rollout_buffer_kwargs = rollout_buffer_kwargs,
|
||||
# support_multi_env=True,
|
||||
seed=seed,
|
||||
stats_window_size=stats_window_size,
|
||||
@ -86,13 +113,20 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
||||
_init_setup_model=_init_setup_model
|
||||
)
|
||||
|
||||
self.rollout_buffer_class = None
|
||||
self.rollout_buffer_kwargs = {}
|
||||
|
||||
def _setup_model(self) -> None:
|
||||
self._setup_lr_schedule()
|
||||
self.set_random_seed(self.seed)
|
||||
|
||||
buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RolloutBuffer
|
||||
if self.rollout_buffer_class is None:
|
||||
if isinstance(self.observation_space, spaces.Dict):
|
||||
self.rollout_buffer_class = DictRolloutBuffer
|
||||
else:
|
||||
self.rollout_buffer_class = RolloutBuffer
|
||||
|
||||
self.rollout_buffer = buffer_cls(
|
||||
self.rollout_buffer = self.rollout_buffer_class(
|
||||
self.n_steps,
|
||||
self.observation_space,
|
||||
self.action_space,
|
||||
@ -100,6 +134,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
||||
gamma=self.gamma,
|
||||
gae_lambda=self.gae_lambda,
|
||||
n_envs=self.n_envs,
|
||||
**self.rollout_buffer_kwargs,
|
||||
)
|
||||
self.policy = self.policy_class( # pytype:disable=not-instantiable
|
||||
self.observation_space,
|
||||
@ -158,6 +193,13 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
||||
clipped_actions = actions
|
||||
# Clip the actions to avoid out of bound error
|
||||
if isinstance(self.action_space, spaces.Box):
|
||||
if self.policy.squash_output:
|
||||
# Unscale the actions to match env bounds
|
||||
# if they were previously squashed (scaled in [-1, 1])
|
||||
clipped_actions = self.policy.unscale_action(clipped_actions)
|
||||
else:
|
||||
# Otherwise, clip the actions to avoid out of bound error
|
||||
# as we are sampling from an unbounded Gaussian distribution
|
||||
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
||||
|
||||
new_obs, rewards, dones, infos = env.step(clipped_actions)
|
||||
@ -166,7 +208,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
||||
|
||||
# Give access to local variables
|
||||
callback.update_locals(locals())
|
||||
if callback.on_step() is False:
|
||||
if not callback.on_step():
|
||||
return False
|
||||
|
||||
self._update_info_buffer(infos)
|
||||
@ -199,6 +241,8 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
||||
|
||||
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
|
||||
|
||||
callback.update_locals(locals())
|
||||
|
||||
callback.on_rollout_end()
|
||||
|
||||
return True
|
||||
@ -234,7 +278,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
||||
while self.num_timesteps < total_timesteps:
|
||||
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
|
||||
|
||||
if continue_training is False:
|
||||
if not continue_training:
|
||||
break
|
||||
|
||||
iteration += 1
|
||||
@ -242,6 +286,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
||||
|
||||
# Display training infos
|
||||
if log_interval is not None and iteration % log_interval == 0:
|
||||
assert self.ep_info_buffer is not None
|
||||
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
||||
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
||||
|
5
metastable_baselines2/common/pca.py
Normal file
5
metastable_baselines2/common/pca.py
Normal file
@ -0,0 +1,5 @@
|
||||
try:
|
||||
from priorConditionedAnnealing import PCA_Distribution
|
||||
except ModuleNotFoundError:
|
||||
def PCA_Distribution(*args, **kwargs):
|
||||
raise Exception('PCA is not installed; cannot initialize PCA_Distribution.')
|
@ -34,11 +34,12 @@ from stable_baselines3.common.torch_layers import (
|
||||
|
||||
from stable_baselines3.common.policies import ContinuousCritic
|
||||
|
||||
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
|
||||
from stable_baselines3.common.type_aliases import Schedule
|
||||
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
|
||||
|
||||
from .distributions import make_proba_distribution
|
||||
from metastable_baselines2.pca import PCA_Distribution
|
||||
from metastable_baselines2.common.pca import PCA_Distribution
|
||||
|
||||
|
||||
SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")
|
||||
|
||||
@ -896,7 +897,7 @@ class Actor(BasePolicy):
|
||||
else:
|
||||
self.action_dist.base_noise.reset()
|
||||
|
||||
def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
|
||||
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
|
||||
"""
|
||||
Get the parameters for the action distribution.
|
||||
|
||||
@ -917,17 +918,17 @@ class Actor(BasePolicy):
|
||||
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
|
||||
return mean_actions, log_std, {}
|
||||
|
||||
def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
||||
# Note: the action is squashed
|
||||
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
|
||||
|
||||
def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]:
|
||||
def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
||||
# return action and associated log prob
|
||||
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
|
||||
|
||||
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
return self(observation, deterministic)
|
||||
|
||||
|
||||
@ -1106,10 +1107,10 @@ class SACPolicy(BasePolicy):
|
||||
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
|
||||
return ContinuousCritic(**critic_kwargs).to(self.device)
|
||||
|
||||
def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
return self._predict(obs, deterministic=deterministic)
|
||||
|
||||
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
return self.actor(observation, deterministic)
|
||||
|
||||
def set_training_mode(self, mode: bool) -> None:
|
||||
|
@ -113,7 +113,7 @@ class SAC(BetterOffPolicyAlgorithm):
|
||||
use_pca: bool = False,
|
||||
stats_window_size: int = 100,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||
policy_kwargs: Optional[Dict[str, Any]] = {},
|
||||
verbose: int = 0,
|
||||
seed: Optional[int] = None,
|
||||
device: Union[th.device, str] = "auto",
|
||||
@ -156,7 +156,7 @@ class SAC(BetterOffPolicyAlgorithm):
|
||||
# Inverse of the reward scale
|
||||
self.ent_coef = ent_coef
|
||||
self.target_update_interval = target_update_interval
|
||||
self.ent_coef_optimizer = Optional[th.optim.Adam] = None
|
||||
self.ent_coef_optimizer = None
|
||||
|
||||
if _init_setup_model:
|
||||
self._setup_model()
|
||||
|
Loading…
Reference in New Issue
Block a user