From 9a4c43e2337fed720a389d848f6904e19b1ea68e Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Mon, 21 Aug 2023 16:43:41 +0200
Subject: [PATCH] Implemented binding to PCA

---
 sbBrix/common/distributions.py        |   45 +
 sbBrix/common/off_policy_algorithm.py |  554 +++++++++++++
 sbBrix/common/on_policy_algorithm.py  |  265 ++++++
 sbBrix/common/policies.py             | 1084 +++++++++++++++++++++++++
 sbBrix/ppo/ppo.py                     |   15 +-
 sbBrix/sac/sac.py                     |    6 +-
 6 files changed, 1959 insertions(+), 10 deletions(-)
 create mode 100644 sbBrix/common/distributions.py
 create mode 100644 sbBrix/common/off_policy_algorithm.py
 create mode 100644 sbBrix/common/on_policy_algorithm.py
 create mode 100644 sbBrix/common/policies.py

diff --git a/sbBrix/common/distributions.py b/sbBrix/common/distributions.py
new file mode 100644
index 0000000..9e48ac4
--- /dev/null
+++ b/sbBrix/common/distributions.py
@@ -0,0 +1,45 @@
+from stable_baselines3.common.distributions import *
+from priorConditionedAnnealing import PCA_Distribution
+
+
+def _patched_make_proba_distribution(
+    action_space: spaces.Space, use_sde: bool = False, use_pca: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
+) -> Distribution:
+    """
+    Return an instance of Distribution for the correct type of action space
+
+    :param action_space: the input action space
+    :param use_sde: Force the use of StateDependentNoiseDistribution
+        instead of DiagGaussianDistribution
+    :param dist_kwargs: Keyword arguments to pass to the probability distribution
+    :return: the appropriate Distribution object
+    """
+
+    assert not (use_sde and use_pca), 'Cannot mix SDE and PCA!'
+
+    if dist_kwargs is None:
+        dist_kwargs = {}
+
+    if isinstance(action_space, spaces.Box):
+        if use_sde:
+            cls = StateDependentNoiseDistribution
+        elif use_pca:
+            cls = PCA_Distribution
+        else:
+            cls = DiagGaussianDistribution
+        return cls(get_action_dim(action_space), **dist_kwargs)
+    elif isinstance(action_space, spaces.Discrete):
+        return CategoricalDistribution(action_space.n, **dist_kwargs)
+    elif isinstance(action_space, spaces.MultiDiscrete):
+        return MultiCategoricalDistribution(list(action_space.nvec), **dist_kwargs)
+    elif isinstance(action_space, spaces.MultiBinary):
+        return BernoulliDistribution(action_space.n, **dist_kwargs)
+    else:
+        raise NotImplementedError(
+            "Error: probability distribution, not implemented for action space "
+            f"of type {type(action_space)}."
+            " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary."
+        )
+
+
+_orig_make_proba_distribution, make_proba_distribution = make_proba_distribution, _patched_make_proba_distribution
diff --git a/sbBrix/common/off_policy_algorithm.py b/sbBrix/common/off_policy_algorithm.py
new file mode 100644
index 0000000..4d5e0dd
--- /dev/null
+++ b/sbBrix/common/off_policy_algorithm.py
@@ -0,0 +1,554 @@
+from stable_baselines3.common.off_policy_algorithm import *
+
+
+class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
+    """
+    The base for Off-Policy algorithms (ex: SAC/TD3)
+
+    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
+    :param env: The environment to learn from
+                (if registered in Gym, can be str.
Can be None for loading trained models) + :param learning_rate: learning rate for the optimizer, + it can be a function of the current progress remaining (from 1 to 0) + :param buffer_size: size of the replay buffer + :param learning_starts: how many steps of the model to collect transitions for before learning starts + :param batch_size: Minibatch size for each gradient update + :param tau: the soft update coefficient ("Polyak update", between 0 and 1) + :param gamma: the discount factor + :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit + like ``(5, "step")`` or ``(2, "episode")``. + :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``) + Set to ``-1`` means to do as many gradient steps as steps done in the environment + during the rollout. + :param action_noise: the action noise type (None by default), this can help + for hard exploration problem. Cf common.noise for the different action noise type. + :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). + If ``None``, it will be automatically selected. + :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. + :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer + at a cost of more complexity. + See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 + :param policy_kwargs: Additional arguments to be passed to the policy on creation + :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average + the reported success rate, mean episode length, and mean reward over + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for + debug messages + :param device: Device on which the code should run. + By default, it will try to use a Cuda compatible device and fallback to cpu + if it is not possible. + :param support_multi_env: Whether the algorithm supports training + with multiple environments (as in A2C) + :param monitor_wrapper: When creating an environment, whether to wrap it + or not in a Monitor wrapper. + :param seed: Seed for the pseudo random generators + :param use_sde: Whether to use State Dependent Exploration (SDE) + instead of action noise exploration (default: False) + :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE + Default: -1 (only sample at the beginning of the rollout) + :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling + during the warm up phase (before learning starts) + :param sde_support: Whether the model support gSDE or not + :param supported_action_spaces: The action spaces supported by the algorithm. 
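+
+    Example -- a minimal usage sketch. It assumes the concrete ``sbBrix`` SAC subclass keeps
+    the stable-baselines3 constructor signature and simply forwards the new ``use_pca`` flag
+    down to this base class (the import path below is an assumption)::
+
+        from sbBrix import SAC
+
+        model = SAC(
+            "MlpPolicy",
+            "Pendulum-v1",
+            use_pca=True,  # mutually exclusive with use_sde
+            learning_starts=1_000,
+            train_freq=(1, "step"),
+            verbose=1,
+        )
+        model.learn(total_timesteps=10_000)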
+ """ + + def __init__( + self, + policy: Union[str, Type[BasePolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule], + buffer_size: int = 1_000_000, # 1e6 + learning_starts: int = 100, + batch_size: int = 256, + tau: float = 0.005, + gamma: float = 0.99, + train_freq: Union[int, Tuple[int, str]] = (1, "step"), + gradient_steps: int = 1, + action_noise: Optional[ActionNoise] = None, + replay_buffer_class: Optional[Type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + optimize_memory_usage: bool = False, + policy_kwargs: Optional[Dict[str, Any]] = None, + stats_window_size: int = 100, + tensorboard_log: Optional[str] = None, + verbose: int = 0, + device: Union[th.device, str] = "auto", + support_multi_env: bool = False, + monitor_wrapper: bool = True, + seed: Optional[int] = None, + use_sde: bool = False, + sde_sample_freq: int = -1, + use_sde_at_warmup: bool = False, + sde_support: bool = True, + use_pca: bool = False, + supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None, + ): + assert not (use_sde and use_pca) + self.use_pca = use_pca + super().__init__( + policy=policy, + env=env, + learning_rate=learning_rate, + policy_kwargs=policy_kwargs, + stats_window_size=stats_window_size, + tensorboard_log=tensorboard_log, + verbose=verbose, + device=device, + support_multi_env=support_multi_env, + monitor_wrapper=monitor_wrapper, + seed=seed, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + supported_action_spaces=supported_action_spaces, + buffer_size=buffer_size, + batch_size=batch_size, + learning_starts=learning_starts, + tau=tau, + gamma=gamma, + gradient_steps=gradient_steps, + action_noise=action_noise, + optimize_memory_usage=optimize_memory_usage, + replay_buffer_class=replay_buffer_class, + replay_buffer_kwargs=replay_buffer_kwargs, + train_freq=train_freq, + use_sde_at_warmup=use_sde_at_warmup, + sde_support=sde_support, + ) + + def _convert_train_freq(self) -> None: + """ + Convert `train_freq` parameter (int or tuple) + to a TrainFreq object. + """ + if not isinstance(self.train_freq, TrainFreq): + train_freq = self.train_freq + + # The value of the train frequency will be checked later + if not isinstance(train_freq, tuple): + train_freq = (train_freq, "step") + + try: + train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1])) + except ValueError as e: + raise ValueError( + f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!" 
+ ) from e + + if not isinstance(train_freq[0], int): + raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}") + + self.train_freq = TrainFreq(*train_freq) + + def _setup_model(self) -> None: + self._setup_lr_schedule() + self.set_random_seed(self.seed) + + if self.replay_buffer_class is None: + if isinstance(self.observation_space, spaces.Dict): + self.replay_buffer_class = DictReplayBuffer + else: + self.replay_buffer_class = ReplayBuffer + + if self.replay_buffer is None: + # Make a local copy as we should not pickle + # the environment when using HerReplayBuffer + replay_buffer_kwargs = self.replay_buffer_kwargs.copy() + if issubclass(self.replay_buffer_class, HerReplayBuffer): + assert self.env is not None, "You must pass an environment when using `HerReplayBuffer`" + replay_buffer_kwargs["env"] = self.env + self.replay_buffer = self.replay_buffer_class( + self.buffer_size, + self.observation_space, + self.action_space, + device=self.device, + n_envs=self.n_envs, + optimize_memory_usage=self.optimize_memory_usage, + **replay_buffer_kwargs, # pytype:disable=wrong-keyword-args + ) + + self.policy = self.policy_class( # pytype:disable=not-instantiable + self.observation_space, + self.action_space, + self.lr_schedule, + **self.policy_kwargs, # pytype:disable=not-instantiable + ) + self.policy = self.policy.to(self.device) + + # Convert train freq parameter to TrainFreq object + self._convert_train_freq() + + def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None: + """ + Save the replay buffer as a pickle file. + + :param path: Path to the file where the replay buffer should be saved. + if path is a str or pathlib.Path, the path is automatically created if necessary. + """ + assert self.replay_buffer is not None, "The replay buffer is not defined" + save_to_pkl(path, self.replay_buffer, self.verbose) + + def load_replay_buffer( + self, + path: Union[str, pathlib.Path, io.BufferedIOBase], + truncate_last_traj: bool = True, + ) -> None: + """ + Load a replay buffer from a pickle file. + + :param path: Path to the pickled replay buffer. + :param truncate_last_traj: When using ``HerReplayBuffer`` with online sampling: + If set to ``True``, we assume that the last trajectory in the replay buffer was finished + (and truncate it). + If set to ``False``, we assume that we continue the same trajectory (same episode). + """ + self.replay_buffer = load_from_pkl(path, self.verbose) + assert isinstance(self.replay_buffer, ReplayBuffer), "The replay buffer must inherit from ReplayBuffer class" + + # Backward compatibility with SB3 < 2.1.0 replay buffer + # Keep old behavior: do not handle timeout termination separately + if not hasattr(self.replay_buffer, "handle_timeout_termination"): # pragma: no cover + self.replay_buffer.handle_timeout_termination = False + self.replay_buffer.timeouts = np.zeros_like(self.replay_buffer.dones) + + if isinstance(self.replay_buffer, HerReplayBuffer): + assert self.env is not None, "You must pass an environment at load time when using `HerReplayBuffer`" + self.replay_buffer.set_env(self.get_env()) + if truncate_last_traj: + self.replay_buffer.truncate_last_trajectory() + + def _setup_learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + reset_num_timesteps: bool = True, + tb_log_name: str = "run", + progress_bar: bool = False, + ) -> Tuple[int, BaseCallback]: + """ + cf `BaseAlgorithm`. 
+ """ + # Prevent continuity issue by truncating trajectory + # when using memory efficient replay buffer + # see https://github.com/DLR-RM/stable-baselines3/issues/46 + + replay_buffer = self.replay_buffer + + truncate_last_traj = ( + self.optimize_memory_usage + and reset_num_timesteps + and replay_buffer is not None + and (replay_buffer.full or replay_buffer.pos > 0) + ) + + if truncate_last_traj: + warnings.warn( + "The last trajectory in the replay buffer will be truncated, " + "see https://github.com/DLR-RM/stable-baselines3/issues/46." + "You should use `reset_num_timesteps=False` or `optimize_memory_usage=False`" + "to avoid that issue." + ) + # Go to the previous index + pos = (replay_buffer.pos - 1) % replay_buffer.buffer_size + replay_buffer.dones[pos] = True + + return super()._setup_learn( + total_timesteps, + callback, + reset_num_timesteps, + tb_log_name, + progress_bar, + ) + + def learn( + self: SelfOffPolicyAlgorithm, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 4, + tb_log_name: str = "run", + reset_num_timesteps: bool = True, + progress_bar: bool = False, + ) -> SelfOffPolicyAlgorithm: + total_timesteps, callback = self._setup_learn( + total_timesteps, + callback, + reset_num_timesteps, + tb_log_name, + progress_bar, + ) + + callback.on_training_start(locals(), globals()) + + while self.num_timesteps < total_timesteps: + rollout = self.collect_rollouts( + self.env, + train_freq=self.train_freq, + action_noise=self.action_noise, + callback=callback, + learning_starts=self.learning_starts, + replay_buffer=self.replay_buffer, + log_interval=log_interval, + ) + + if rollout.continue_training is False: + break + + if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts: + # If no `gradient_steps` is specified, + # do as many gradients steps as steps performed during the rollout + gradient_steps = self.gradient_steps if self.gradient_steps >= 0 else rollout.episode_timesteps + # Special case when the user passes `gradient_steps=0` + if gradient_steps > 0: + self.train(batch_size=self.batch_size, gradient_steps=gradient_steps) + + callback.on_training_end() + + return self + + def train(self, gradient_steps: int, batch_size: int) -> None: + """ + Sample the replay buffer and do the updates + (gradient descent and update target networks) + """ + raise NotImplementedError() + + def _sample_action( + self, + learning_starts: int, + action_noise: Optional[ActionNoise] = None, + n_envs: int = 1, + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Sample an action according to the exploration policy. + This is either done by sampling the probability distribution of the policy, + or sampling a random action (from a uniform distribution over the action space) + or by adding noise to the deterministic output. + + :param action_noise: Action noise that will be used for exploration + Required for deterministic policy (e.g. TD3). This can also be used + in addition to the stochastic policy for SAC. + :param learning_starts: Number of steps before learning for the warm-up phase. + :param n_envs: + :return: action to take in the environment + and scaled action that will be stored in the replay buffer. + The two differs when the action space is not normalized (bounds are not [-1, 1]). 
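+
+        A minimal sketch of that scaling for a ``Box`` action space (the bounds below are
+        made up for illustration)::
+
+            import numpy as np
+            from gym import spaces
+
+            space = spaces.Box(low=np.array([0.0]), high=np.array([4.0]))
+            low, high = space.low, space.high
+            unscaled = np.array([3.0])                              # what the env receives
+            scaled = 2.0 * (unscaled - low) / (high - low) - 1.0    # array([0.5]), stored in the buffer
+            restored = low + 0.5 * (scaled + 1.0) * (high - low)    # array([3.]), back to env scale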
+ """ + # Select action randomly or according to policy + if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup): + # Warmup phase + unscaled_action = np.array([self.action_space.sample() for _ in range(n_envs)]) + else: + # Note: when using continuous actions, + # we assume that the policy uses tanh to scale the action + # We use non-deterministic action in the case of SAC, for TD3, it does not matter + unscaled_action, _ = self.predict(self._last_obs, deterministic=False) + + # Rescale the action from [low, high] to [-1, 1] + if isinstance(self.action_space, spaces.Box): + scaled_action = self.policy.scale_action(unscaled_action) + + # Add noise to the action (improve exploration) + if action_noise is not None: + scaled_action = np.clip(scaled_action + action_noise(), -1, 1) + + # We store the scaled action in the buffer + buffer_action = scaled_action + action = self.policy.unscale_action(scaled_action) + else: + # Discrete case, no need to normalize or clip + buffer_action = unscaled_action + action = buffer_action + return action, buffer_action + + def _dump_logs(self) -> None: + """ + Write log. + """ + time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) + fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) + self.logger.record("time/episodes", self._episode_num, exclude="tensorboard") + if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: + self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) + self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) + self.logger.record("time/fps", fps) + self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard") + self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") + if self.use_sde or self.use_pca: + self.logger.record("train/std", (self.actor.get_std()).mean().item()) + + if len(self.ep_success_buffer) > 0: + self.logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer)) + # Pass the number of timesteps for tensorboard + self.logger.dump(step=self.num_timesteps) + + def _on_step(self) -> None: + """ + Method called after each step in the environment. + It is meant to trigger DQN target network update + but can be used for other purposes + """ + pass + + def _store_transition( + self, + replay_buffer: ReplayBuffer, + buffer_action: np.ndarray, + new_obs: Union[np.ndarray, Dict[str, np.ndarray]], + reward: np.ndarray, + dones: np.ndarray, + infos: List[Dict[str, Any]], + ) -> None: + """ + Store transition in the replay buffer. + We store the normalized action and the unnormalized observation. + It also handles terminal observations (because VecEnv resets automatically). + + :param replay_buffer: Replay buffer object where to store the transition. + :param buffer_action: normalized action + :param new_obs: next observation in the current episode + or first observation of the episode (when dones is True) + :param reward: reward for the current transition + :param dones: Termination signal + :param infos: List of additional information about the transition. + It may contain the terminal observations and information about timeout. 
+ """ + # Store only the unnormalized version + if self._vec_normalize_env is not None: + new_obs_ = self._vec_normalize_env.get_original_obs() + reward_ = self._vec_normalize_env.get_original_reward() + else: + # Avoid changing the original ones + self._last_original_obs, new_obs_, reward_ = self._last_obs, new_obs, reward + + # Avoid modification by reference + next_obs = deepcopy(new_obs_) + # As the VecEnv resets automatically, new_obs is already the + # first observation of the next episode + for i, done in enumerate(dones): + if done and infos[i].get("terminal_observation") is not None: + if isinstance(next_obs, dict): + next_obs_ = infos[i]["terminal_observation"] + # VecNormalize normalizes the terminal observation + if self._vec_normalize_env is not None: + next_obs_ = self._vec_normalize_env.unnormalize_obs(next_obs_) + # Replace next obs for the correct envs + for key in next_obs.keys(): + next_obs[key][i] = next_obs_[key] + else: + next_obs[i] = infos[i]["terminal_observation"] + # VecNormalize normalizes the terminal observation + if self._vec_normalize_env is not None: + next_obs[i] = self._vec_normalize_env.unnormalize_obs(next_obs[i, :]) + + replay_buffer.add( + self._last_original_obs, + next_obs, + buffer_action, + reward_, + dones, + infos, + ) + + self._last_obs = new_obs + # Save the unnormalized observation + if self._vec_normalize_env is not None: + self._last_original_obs = new_obs_ + + def collect_rollouts( + self, + env: VecEnv, + callback: BaseCallback, + train_freq: TrainFreq, + replay_buffer: ReplayBuffer, + action_noise: Optional[ActionNoise] = None, + learning_starts: int = 0, + log_interval: Optional[int] = None, + ) -> RolloutReturn: + """ + Collect experiences and store them into a ``ReplayBuffer``. + + :param env: The training environment + :param callback: Callback that will be called at each step + (and at the beginning and end of the rollout) + :param train_freq: How much experience to collect + by doing rollouts of current policy. + Either ``TrainFreq(, TrainFrequencyUnit.STEP)`` + or ``TrainFreq(, TrainFrequencyUnit.EPISODE)`` + with ```` being an integer greater than 0. + :param action_noise: Action noise that will be used for exploration + Required for deterministic policy (e.g. TD3). This can also be used + in addition to the stochastic policy for SAC. + :param learning_starts: Number of steps before learning for the warm-up phase. + :param replay_buffer: + :param log_interval: Log data every ``log_interval`` episodes + :return: + """ + # Switch to eval mode (this affects batch norm / dropout) + self.policy.set_training_mode(False) + + num_collected_steps, num_collected_episodes = 0, 0 + + assert isinstance(env, VecEnv), "You must pass a VecEnv" + assert train_freq.frequency > 0, "Should at least collect one step or episode." + + if env.num_envs > 1: + assert train_freq.unit == TrainFrequencyUnit.STEP, "You must use only one env when doing episodic training." 
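+            # Episode-based ``train_freq`` counts finished episodes, which is ambiguous when
+            # several envs terminate asynchronously; hence the single-env restriction above.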
+ + # Vectorize action noise if needed + if action_noise is not None and env.num_envs > 1 and not isinstance(action_noise, VectorizedActionNoise): + action_noise = VectorizedActionNoise(action_noise, env.num_envs) + + if self.use_sde: + self.actor.reset_noise(env.num_envs) + + callback.on_rollout_start() + continue_training = True + while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes): + if self.use_sde and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0: + # Sample a new noise matrix + self.actor.reset_noise(env.num_envs) + + # Select action randomly or according to policy + actions, buffer_actions = self._sample_action(learning_starts, action_noise, env.num_envs) + + # Rescale and perform action + new_obs, rewards, dones, infos = env.step(actions) + + self.num_timesteps += env.num_envs + num_collected_steps += 1 + + # Give access to local variables + callback.update_locals(locals()) + # Only stop training if return value is False, not when it is None. + if callback.on_step() is False: + return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training=False) + + # Retrieve reward and episode length if using Monitor wrapper + self._update_info_buffer(infos, dones) + + # Store data in replay buffer (normalized action and unnormalized observation) + self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos) + + self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps) + + # For DQN, check if the target network should be updated + # and update the exploration schedule + # For SAC/TD3, the update is dones as the same time as the gradient update + # see https://github.com/hill-a/stable-baselines/issues/900 + self._on_step() + + for idx, done in enumerate(dones): + if done: + # Update stats + num_collected_episodes += 1 + self._episode_num += 1 + + if action_noise is not None: + kwargs = dict(indices=[idx]) if env.num_envs > 1 else {} + action_noise.reset(**kwargs) + + # Log training infos + if log_interval is not None and self._episode_num % log_interval == 0: + self._dump_logs() + callback.on_rollout_end() + + return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training) diff --git a/sbBrix/common/on_policy_algorithm.py b/sbBrix/common/on_policy_algorithm.py new file mode 100644 index 0000000..a6d6e88 --- /dev/null +++ b/sbBrix/common/on_policy_algorithm.py @@ -0,0 +1,265 @@ +from stable_baselines3.common.on_policy_algorithm import * + + +class BetterOnPolicyAlgorithm(OnPolicyAlgorithm): + """ + The base for On-Policy algorithms (ex: A2C/PPO). + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function + of the current progress remaining (from 1 to 0) + :param n_steps: The number of steps to run for each environment per update + (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) + :param gamma: Discount factor + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator. + Equivalent to classic advantage when set to 1. 
+ :param ent_coef: Entropy coefficient for the loss calculation + :param vf_coef: Value function coefficient for the loss calculation + :param max_grad_norm: The maximum value for the gradient clipping + :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) + instead of action noise exploration (default: False) + :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE + Default: -1 (only sample at the beginning of the rollout) + :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average + the reported success rate, mean episode length, and mean reward over + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param monitor_wrapper: When creating an environment, whether to wrap it + or not in a Monitor wrapper. + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for + debug messages + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. + :param _init_setup_model: Whether or not to build the network at the creation of the instance + :param supported_action_spaces: The action spaces supported by the algorithm. + """ + + def __init__( + self, + policy: Union[str, Type[ActorCriticPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule], + n_steps: int, + gamma: float, + gae_lambda: float, + ent_coef: float, + vf_coef: float, + max_grad_norm: float, + use_sde: bool, + sde_sample_freq: int, + use_pca: bool, + stats_window_size: int = 100, + tensorboard_log: Optional[str] = None, + monitor_wrapper: bool = True, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None, + ): + assert not (use_sde and use_pca) + self.use_pca = use_pca + super().__init__( + policy=policy, + env=env, + learning_rate=learning_rate, + n_steps=n_steps, + gamma=gamma, + gae_lambda=gae_lambda, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + policy_kwargs=policy_kwargs, + verbose=verbose, + device=device, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + # support_multi_env=True, + seed=seed, + stats_window_size=stats_window_size, + tensorboard_log=tensorboard_log, + supported_action_spaces=supported_action_spaces, + monitor_wrapper=monitor_wrapper, + _init_setup_model=_init_setup_model + ) + + def _setup_model(self) -> None: + self._setup_lr_schedule() + self.set_random_seed(self.seed) + + buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RolloutBuffer + + self.rollout_buffer = buffer_cls( + self.n_steps, + self.observation_space, + self.action_space, + device=self.device, + gamma=self.gamma, + gae_lambda=self.gae_lambda, + n_envs=self.n_envs, + ) + self.policy = self.policy_class( # pytype:disable=not-instantiable + self.observation_space, + self.action_space, + self.lr_schedule, + use_sde=self.use_sde, + use_pca=self.use_pca, + **self.policy_kwargs # pytype:disable=not-instantiable + ) + self.policy = self.policy.to(self.device) + + def collect_rollouts( + self, + env: VecEnv, + callback: BaseCallback, + rollout_buffer: RolloutBuffer, + 
n_rollout_steps: int, + ) -> bool: + """ + Collect experiences using the current policy and fill a ``RolloutBuffer``. + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + + :param env: The training environment + :param callback: Callback that will be called at each step + (and at the beginning and end of the rollout) + :param rollout_buffer: Buffer to fill with rollouts + :param n_rollout_steps: Number of experiences to collect per environment + :return: True if function returned with at least `n_rollout_steps` + collected, False if callback terminated rollout prematurely. + """ + assert self._last_obs is not None, "No previous observation was provided" + # Switch to eval mode (this affects batch norm / dropout) + self.policy.set_training_mode(False) + + n_steps = 0 + rollout_buffer.reset() + # Sample new weights for the state dependent exploration + if self.use_sde: + self.policy.reset_noise(env.num_envs) + + callback.on_rollout_start() + + while n_steps < n_rollout_steps: + if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: + # Sample a new noise matrix + self.policy.reset_noise(env.num_envs) + + with th.no_grad(): + # Convert to pytorch tensor or to TensorDict + obs_tensor = obs_as_tensor(self._last_obs, self.device) + actions, values, log_probs = self.policy(obs_tensor) + actions = actions.cpu().numpy() + + # Rescale and perform action + clipped_actions = actions + # Clip the actions to avoid out of bound error + if isinstance(self.action_space, spaces.Box): + clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) + + new_obs, rewards, dones, infos = env.step(clipped_actions) + + self.num_timesteps += env.num_envs + + # Give access to local variables + callback.update_locals(locals()) + if callback.on_step() is False: + return False + + self._update_info_buffer(infos) + n_steps += 1 + + if isinstance(self.action_space, spaces.Discrete): + # Reshape in case of discrete action + actions = actions.reshape(-1, 1) + + # Handle timeout by bootstraping with value function + # see GitHub issue #633 + for idx, done in enumerate(dones): + if ( + done + and infos[idx].get("terminal_observation") is not None + and infos[idx].get("TimeLimit.truncated", False) + ): + terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] + with th.no_grad(): + terminal_value = self.policy.predict_values(terminal_obs)[0] + rewards[idx] += self.gamma * terminal_value + + rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs) + self._last_obs = new_obs + self._last_episode_starts = dones + + with th.no_grad(): + # Compute value for the last timestep + values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) + + rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) + + callback.on_rollout_end() + + return True + + def train(self) -> None: + """ + Consume current rollout data and update policy parameters. + Implemented by individual algorithms. 
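+
+        A schematic sketch of what a subclass typically does here (an A2C-style update,
+        illustrative only; PPO adds clipping, minibatching and multiple epochs)::
+
+            def train(self) -> None:
+                self.policy.set_training_mode(True)
+                self._update_learning_rate(self.policy.optimizer)
+                for rollout_data in self.rollout_buffer.get(batch_size=None):
+                    values, log_prob, entropy = self.policy.evaluate_actions(
+                        rollout_data.observations, rollout_data.actions
+                    )
+                    policy_loss = -(rollout_data.advantages * log_prob).mean()
+                    value_loss = th.nn.functional.mse_loss(rollout_data.returns, values.flatten())
+                    # Note: entropy can be None for some distributions (A2C then falls back to -log_prob)
+                    loss = policy_loss + self.vf_coef * value_loss - self.ent_coef * entropy.mean()
+                    self.policy.optimizer.zero_grad()
+                    loss.backward()
+                    th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+                    self.policy.optimizer.step()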
+ """ + raise NotImplementedError + + def learn( + self: SelfOnPolicyAlgorithm, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 1, + tb_log_name: str = "OnPolicyAlgorithm", + reset_num_timesteps: bool = True, + progress_bar: bool = False, + ) -> SelfOnPolicyAlgorithm: + iteration = 0 + + total_timesteps, callback = self._setup_learn( + total_timesteps, + callback, + reset_num_timesteps, + tb_log_name, + progress_bar, + ) + + callback.on_training_start(locals(), globals()) + + while self.num_timesteps < total_timesteps: + continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) + + if continue_training is False: + break + + iteration += 1 + self._update_current_progress_remaining(self.num_timesteps, total_timesteps) + + # Display training infos + if log_interval is not None and iteration % log_interval == 0: + time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) + fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) + self.logger.record("time/iterations", iteration, exclude="tensorboard") + if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: + self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) + self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) + self.logger.record("time/fps", fps) + self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard") + self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") + self.logger.dump(step=self.num_timesteps) + + self.train() + + callback.on_training_end() + + return self + + def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: + state_dicts = ["policy", "policy.optimizer"] + + return state_dicts, [] diff --git a/sbBrix/common/policies.py b/sbBrix/common/policies.py new file mode 100644 index 0000000..df1c1cc --- /dev/null +++ b/sbBrix/common/policies.py @@ -0,0 +1,1084 @@ +"""Policies: abstract base class and concrete implementations.""" + +import collections +import copy +import warnings +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union + +import numpy as np +import torch as th +from gym import spaces +from torch import nn + +from stable_baselines3.common.distributions import ( + BernoulliDistribution, + CategoricalDistribution, + DiagGaussianDistribution, + Distribution, + MultiCategoricalDistribution, + StateDependentNoiseDistribution, + SquashedDiagGaussianDistribution, +) +from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, maybe_transpose, preprocess_obs +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + MlpExtractor, + NatureCNN, + create_mlp, + get_actor_critic_arch, +) + +from stable_baselines3.common.policies import ContinuousCritic + +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor + +from .distributions import make_proba_distribution +from priorConditionedAnnealing import PCA_Distribution + +SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel") + + +class BaseModel(nn.Module): + """ + The base model object: makes predictions in response to observations. + + In the case of policies, the prediction is an action. 
In the case of critics, it is the + estimated value of the observation. + + :param observation_space: The observation space of the environment + :param action_space: The action space of the environment + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param features_extractor: Network to extract features + (a CNN when using images, a nn.Flatten() layer otherwise) + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Space, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + features_extractor: Optional[BaseFeaturesExtractor] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__() + + if optimizer_kwargs is None: + optimizer_kwargs = {} + + if features_extractor_kwargs is None: + features_extractor_kwargs = {} + + self.observation_space = observation_space + self.action_space = action_space + self.features_extractor = features_extractor + self.normalize_images = normalize_images + + self.optimizer_class = optimizer_class + self.optimizer_kwargs = optimizer_kwargs + self.optimizer: th.optim.Optimizer + + self.features_extractor_class = features_extractor_class + self.features_extractor_kwargs = features_extractor_kwargs + # Automatically deactivate dtype and bounds checks + if normalize_images is False and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)): + self.features_extractor_kwargs.update(dict(normalized_image=True)) + + def _update_features_extractor( + self, + net_kwargs: Dict[str, Any], + features_extractor: Optional[BaseFeaturesExtractor] = None, + ) -> Dict[str, Any]: + """ + Update the network keyword arguments and create a new features extractor object if needed. + If a ``features_extractor`` object is passed, then it will be shared. + + :param net_kwargs: the base network keyword arguments, without the ones + related to features extractor + :param features_extractor: a features extractor object. + If None, a new object will be created. + :return: The updated keyword arguments + """ + net_kwargs = net_kwargs.copy() + if features_extractor is None: + # The features extractor is not shared, create a new one + features_extractor = self.make_features_extractor() + net_kwargs.update(dict(features_extractor=features_extractor, features_dim=features_extractor.features_dim)) + return net_kwargs + + def make_features_extractor(self) -> BaseFeaturesExtractor: + """Helper method to create a features extractor.""" + return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs) + + def extract_features(self, obs: th.Tensor, features_extractor: BaseFeaturesExtractor) -> th.Tensor: + """ + Preprocess the observation if needed and extract features. + + :param obs: The observation + :param features_extractor: The features extractor to use. 
+ :return: The extracted features + """ + preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) + return features_extractor(preprocessed_obs) + + def _get_constructor_parameters(self) -> Dict[str, Any]: + """ + Get data that need to be saved in order to re-create the model when loading it from disk. + + :return: The dictionary to pass to the as kwargs constructor when reconstruction this model. + """ + return dict( + observation_space=self.observation_space, + action_space=self.action_space, + # Passed to the constructor by child class + # squash_output=self.squash_output, + # features_extractor=self.features_extractor + normalize_images=self.normalize_images, + ) + + @property + def device(self) -> th.device: + """Infer which device this policy lives on by inspecting its parameters. + If it has no parameters, the 'cpu' device is used as a fallback. + + :return:""" + for param in self.parameters(): + return param.device + return get_device("cpu") + + def save(self, path: str) -> None: + """ + Save model to a given location. + + :param path: + """ + th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path) + + @classmethod + def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "auto") -> SelfBaseModel: + """ + Load model from path. + + :param path: + :param device: Device on which the policy should be loaded. + :return: + """ + device = get_device(device) + saved_variables = th.load(path, map_location=device) + + # Create policy object + model = cls(**saved_variables["data"]) # pytype: disable=not-instantiable + # Load weights + model.load_state_dict(saved_variables["state_dict"]) + model.to(device) + return model + + def load_from_vector(self, vector: np.ndarray) -> None: + """ + Load parameters from a 1D vector. + + :param vector: + """ + th.nn.utils.vector_to_parameters(th.as_tensor(vector, dtype=th.float, device=self.device), self.parameters()) + + def parameters_to_vector(self) -> np.ndarray: + """ + Convert the parameters to a 1D vector. + + :return: + """ + return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy() + + def set_training_mode(self, mode: bool) -> None: + """ + Put the policy in either training or evaluation mode. + + This affects certain modules, such as batch normalisation and dropout. + + :param mode: if true, set to training mode, else set to evaluation mode + """ + self.train(mode) + + def is_vectorized_observation(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> bool: + """ + Check whether or not the observation is vectorized, + apply transposition to image (so that they are channel-first) if needed. + This is used in DQN when sampling random action (epsilon-greedy policy) + + :param observation: the input observation to check + :return: whether the given observation is vectorized or not + """ + vectorized_env = False + if isinstance(observation, dict): + for key, obs in observation.items(): + obs_space = self.observation_space.spaces[key] + vectorized_env = vectorized_env or is_vectorized_observation(maybe_transpose(obs, obs_space), obs_space) + else: + vectorized_env = is_vectorized_observation( + maybe_transpose(observation, self.observation_space), self.observation_space + ) + return vectorized_env + + def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[th.Tensor, bool]: + """ + Convert an input observation to a PyTorch tensor that can be fed to a model. 
+ Includes sugar-coating to handle different observations (e.g. normalizing images). + + :param observation: the input observation + :return: The observation as PyTorch tensor + and whether the observation is vectorized or not + """ + vectorized_env = False + if isinstance(observation, dict): + # need to copy the dict as the dict in VecFrameStack will become a torch tensor + observation = copy.deepcopy(observation) + for key, obs in observation.items(): + obs_space = self.observation_space.spaces[key] + if is_image_space(obs_space): + obs_ = maybe_transpose(obs, obs_space) + else: + obs_ = np.array(obs) + vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space) + # Add batch dimension if needed + observation[key] = obs_.reshape((-1, *self.observation_space[key].shape)) + + elif is_image_space(self.observation_space): + # Handle the different cases for images + # as PyTorch use channel first format + observation = maybe_transpose(observation, self.observation_space) + + else: + observation = np.array(observation) + + if not isinstance(observation, dict): + # Dict obs need to be handled separately + vectorized_env = is_vectorized_observation(observation, self.observation_space) + # Add batch dimension if needed + observation = observation.reshape((-1, *self.observation_space.shape)) + + observation = obs_as_tensor(observation, self.device) + return observation, vectorized_env + + +class BasePolicy(BaseModel, ABC): + """The base policy object. + + Parameters are mostly the same as `BaseModel`; additions are documented below. + + :param args: positional arguments passed through to `BaseModel`. + :param kwargs: keyword arguments passed through to `BaseModel`. + :param squash_output: For continuous actions, whether the output is squashed + or not using a ``tanh()`` function. + """ + + def __init__(self, *args, squash_output: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self._squash_output = squash_output + + @staticmethod + def _dummy_schedule(progress_remaining: float) -> float: + """(float) Useful for pickling policy.""" + del progress_remaining + return 0.0 + + @property + def squash_output(self) -> bool: + """(bool) Getter for squash_output.""" + return self._squash_output + + @staticmethod + def init_weights(module: nn.Module, gain: float = 1) -> None: + """ + Orthogonal initialization (used in PPO and A2C) + """ + if isinstance(module, (nn.Linear, nn.Conv2d)): + nn.init.orthogonal_(module.weight, gain=gain) + if module.bias is not None: + module.bias.data.fill_(0.0) + + @abstractmethod + def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + """ + Get the action according to the policy for a given observation. + + By default provides a dummy implementation -- not all BasePolicy classes + implement this, e.g. if they are a Critic in an Actor-Critic method. + + :param observation: + :param deterministic: Whether to use stochastic or deterministic actions + :return: Taken action according to the policy + """ + + def predict( + self, + observation: Union[np.ndarray, Dict[str, np.ndarray]], + state: Optional[Tuple[np.ndarray, ...]] = None, + episode_start: Optional[np.ndarray] = None, + deterministic: bool = False, + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + """ + Get the policy action from an observation (and optional hidden state). + Includes sugar-coating to handle different observations (e.g. normalizing images). 
+ + :param observation: the input observation + :param state: The last hidden states (can be None, used in recurrent policies) + :param episode_start: The last masks (can be None, used in recurrent policies) + this correspond to beginning of episodes, + where the hidden states of the RNN must be reset. + :param deterministic: Whether or not to return deterministic actions. + :return: the model's action and the next hidden state + (used in recurrent policies) + """ + # Switch to eval mode (this affects batch norm / dropout) + self.set_training_mode(False) + + observation, vectorized_env = self.obs_to_tensor(observation) + + with th.no_grad(): + actions = self._predict(observation, deterministic=deterministic) + # Convert to numpy, and reshape to the original action shape + actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) + + if isinstance(self.action_space, spaces.Box): + if self.squash_output: + # Rescale to proper domain when using squashing + actions = self.unscale_action(actions) + else: + # Actions could be on arbitrary scale, so clip the actions to avoid + # out of bound error (e.g. if sampling from a Gaussian distribution) + actions = np.clip(actions, self.action_space.low, self.action_space.high) + + # Remove batch dimension if needed + if not vectorized_env: + actions = actions.squeeze(axis=0) + + return actions, state + + def scale_action(self, action: np.ndarray) -> np.ndarray: + """ + Rescale the action from [low, high] to [-1, 1] + (no need for symmetric action space) + + :param action: Action to scale + :return: Scaled action + """ + low, high = self.action_space.low, self.action_space.high + return 2.0 * ((action - low) / (high - low)) - 1.0 + + def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray: + """ + Rescale the action from [-1, 1] to [low, high] + (no need for symmetric action space) + + :param scaled_action: Action to un-scale + """ + low, high = self.action_space.low, self.action_space.high + return low + (0.5 * (scaled_action + 1.0) * (high - low)) + + +class ActorCriticPolicy(BasePolicy): + """ + Policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param share_features_extractor: If True, the features extractor is shared between the policy and value networks. 
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + use_pca: bool = False, + full_std: bool = True, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + share_features_extractor: bool = True, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + dist_kwargs: Optional[Dict[str, Any]] = {}, + ): + if optimizer_kwargs is None: + optimizer_kwargs = {} + # Small values to avoid NaN in Adam optimizer + if optimizer_class == th.optim.Adam: + optimizer_kwargs["eps"] = 1e-5 + + super().__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=squash_output, + normalize_images=normalize_images, + ) + + if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict): + warnings.warn( + ( + "As shared layers in the mlp_extractor are removed since SB3 v1.8.0, " + "you should now pass directly a dictionary and not a list " + "(net_arch=dict(pi=..., vf=...) 
instead of net_arch=[dict(pi=..., vf=...)])" + ), + ) + net_arch = net_arch[0] + + # Default network architecture, from stable-baselines + if net_arch is None: + if features_extractor_class == NatureCNN: + net_arch = [] + else: + net_arch = dict(pi=[64, 64], vf=[64, 64]) + + self.net_arch = net_arch + self.activation_fn = activation_fn + self.ortho_init = ortho_init + + self.share_features_extractor = share_features_extractor + self.features_extractor = self.make_features_extractor() + self.features_dim = self.features_extractor.features_dim + if self.share_features_extractor: + self.pi_features_extractor = self.features_extractor + self.vf_features_extractor = self.features_extractor + else: + self.pi_features_extractor = self.features_extractor + self.vf_features_extractor = self.make_features_extractor() + + self.log_std_init = log_std_init + # Keyword arguments for gSDE distribution + if use_sde: + add_dist_kwargs = { + "full_std": full_std, + "squash_output": squash_output, + "use_expln": use_expln, + "learn_features": False, + } + dist_kwargs.update(add_dist_kwargs) + + self.use_sde = use_sde + self.use_pca = use_pca + self.dist_kwargs = dist_kwargs + + # Action distribution + self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, use_pca=use_pca, dist_kwargs=dist_kwargs) + + self._build(lr_schedule) + + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() + + default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) + + data.update( + dict( + net_arch=self.net_arch, + activation_fn=self.activation_fn, + use_sde=self.use_sde, + log_std_init=self.log_std_init, + squash_output=default_none_kwargs["squash_output"], + full_std=default_none_kwargs["full_std"], + use_expln=default_none_kwargs["use_expln"], + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + ortho_init=self.ortho_init, + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs, + ) + ) + return data + + def reset_noise(self, n_envs: int = 1) -> None: + """ + Sample new weights for the exploration matrix. + + :param n_envs: + """ + assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE" + self.action_dist.sample_weights(self.log_std, batch_size=n_envs) + + def _build_mlp_extractor(self) -> None: + """ + Create the policy and value networks. + Part of the layers can be shared. + """ + # Note: If net_arch is None and some features extractor is used, + # net_arch here is an empty list and mlp_extractor does not + # really contain any layers (acts like an identity module). + self.mlp_extractor = MlpExtractor( + self.features_dim, + net_arch=self.net_arch, + activation_fn=self.activation_fn, + device=self.device, + ) + + def _build(self, lr_schedule: Schedule) -> None: + """ + Create the networks and the optimizer. 
+ + :param lr_schedule: Learning rate schedule + lr_schedule(1) is the initial learning rate + """ + self._build_mlp_extractor() + + latent_dim_pi = self.mlp_extractor.latent_dim_pi + + if isinstance(self.action_dist, DiagGaussianDistribution): + self.action_net, self.log_std = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi, log_std_init=self.log_std_init + ) + elif isinstance(self.action_dist, StateDependentNoiseDistribution): + self.action_net, self.log_std = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init + ) + elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)): + self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) + else: + raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.") + + self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1) + # Init weights: use orthogonal initialization + # with small initial weight for the output + if self.ortho_init: + # TODO: check for features_extractor + # Values from stable-baselines. + # features_extractor/mlp values are + # originally from openai/baselines (default gains/init_scales). + module_gains = { + self.features_extractor: np.sqrt(2), + self.mlp_extractor: np.sqrt(2), + self.action_net: 0.01, + self.value_net: 1, + } + if not self.share_features_extractor: + # Note(antonin): this is to keep SB3 results + # consistent, see GH#1148 + del module_gains[self.features_extractor] + module_gains[self.pi_features_extractor] = np.sqrt(2) + module_gains[self.vf_features_extractor] = np.sqrt(2) + + for module, gain in module_gains.items(): + module.apply(partial(self.init_weights, gain=gain)) + + # Setup optimizer with initial learning rate + self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + + def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Forward pass in all the networks (actor and critic) + + :param obs: Observation + :param deterministic: Whether to sample or use deterministic actions + :return: action, value and log probability of the action + """ + # Preprocess the observation if needed + features = self.extract_features(obs) + if self.share_features_extractor: + latent_pi, latent_vf = self.mlp_extractor(features) + else: + pi_features, vf_features = features + latent_pi = self.mlp_extractor.forward_actor(pi_features) + latent_vf = self.mlp_extractor.forward_critic(vf_features) + # Evaluate the values for the given observations + values = self.value_net(latent_vf) + distribution = self._get_action_dist_from_latent(latent_pi) + actions = distribution.get_actions(deterministic=deterministic) + log_prob = distribution.log_prob(actions) + actions = actions.reshape((-1, *self.action_space.shape)) + return actions, values, log_prob + + def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]: + """ + Preprocess the observation if needed and extract features. 
+ + :param obs: Observation + :return: the output of the features extractor(s) + """ + if self.share_features_extractor: + return super().extract_features(obs, self.features_extractor) + else: + pi_features = super().extract_features(obs, self.pi_features_extractor) + vf_features = super().extract_features(obs, self.vf_features_extractor) + return pi_features, vf_features + + def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution: + """ + Retrieve action distribution given the latent codes. + + :param latent_pi: Latent code for the actor + :return: Action distribution + """ + mean_actions = self.action_net(latent_pi) + + if isinstance(self.action_dist, DiagGaussianDistribution): + return self.action_dist.proba_distribution(mean_actions, self.log_std) + elif isinstance(self.action_dist, CategoricalDistribution): + # Here mean_actions are the logits before the softmax + return self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, MultiCategoricalDistribution): + # Here mean_actions are the flattened logits + return self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, BernoulliDistribution): + # Here mean_actions are the logits (before rounding to get the binary actions) + return self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, StateDependentNoiseDistribution): + return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi) + else: + raise ValueError("Invalid action distribution") + + def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + """ + Get the action according to the policy for a given observation. + + :param observation: + :param deterministic: Whether to use stochastic or deterministic actions + :return: Taken action according to the policy + """ + return self.get_distribution(observation).get_actions(deterministic=deterministic) + + def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]: + """ + Evaluate actions according to the current policy, + given the observations. + + :param obs: Observation + :param actions: Actions + :return: estimated value, log likelihood of taking those actions + and entropy of the action distribution. + """ + # Preprocess the observation if needed + features = self.extract_features(obs) + if self.share_features_extractor: + latent_pi, latent_vf = self.mlp_extractor(features) + else: + pi_features, vf_features = features + latent_pi = self.mlp_extractor.forward_actor(pi_features) + latent_vf = self.mlp_extractor.forward_critic(vf_features) + distribution = self._get_action_dist_from_latent(latent_pi) + log_prob = distribution.log_prob(actions) + values = self.value_net(latent_vf) + entropy = distribution.entropy() + return values, log_prob, entropy + + def get_distribution(self, obs: th.Tensor) -> Distribution: + """ + Get the current policy distribution given the observations. + + :param obs: + :return: the action distribution. + """ + features = super().extract_features(obs, self.pi_features_extractor) + latent_pi = self.mlp_extractor.forward_actor(features) + return self._get_action_dist_from_latent(latent_pi) + + def predict_values(self, obs: th.Tensor) -> th.Tensor: + """ + Get the estimated values according to the current policy given the observations. + + :param obs: Observation + :return: the estimated values. 
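+
+        Shape sketch (assuming a vectorized, non-dict observation)::
+
+            obs_tensor, _ = policy.obs_to_tensor(observations)  # (n_envs, *obs_shape)
+            values = policy.predict_values(obs_tensor)          # (n_envs, 1)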
+        """
+        features = super().extract_features(obs, self.vf_features_extractor)
+        latent_vf = self.mlp_extractor.forward_critic(features)
+        return self.value_net(latent_vf)
+
+
+class Actor(BasePolicy):
+    """
+    Actor network (policy) for SAC.
+
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param net_arch: Network architecture
+    :param features_extractor: Network to extract features
+        (a CNN when using images, a nn.Flatten() layer otherwise)
+    :param features_dim: Number of features
+    :param activation_fn: Activation function
+    :param use_sde: Whether to use State Dependent Exploration or not
+    :param log_std_init: Initial value for the log standard deviation
+    :param use_pca: Whether to use Prior Conditioned Annealing (PCA) exploration
+        (``PCA_Distribution``) instead of the default squashed Gaussian
+    :param full_std: Whether to use (n_features x n_actions) parameters
+        for the std instead of only (n_features,) when using gSDE.
+    :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
+        a positive standard deviation (cf paper). It allows to keep variance
+        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
+    :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
+    :param normalize_images: Whether to normalize images or not,
+         dividing by 255.0 (True by default)
+    :param dist_kwargs: Additional keyword arguments to pass to the action distribution
+        (forwarded to ``PCA_Distribution``)
+    """
+
+    def __init__(
+        self,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
+        net_arch: List[int],
+        features_extractor: nn.Module,
+        features_dim: int,
+        activation_fn: Type[nn.Module] = nn.ReLU,
+        use_sde: bool = False,
+        log_std_init: float = -3,
+        use_pca: bool = False,
+        full_std: bool = True,
+        use_expln: bool = False,
+        clip_mean: float = 2.0,
+        normalize_images: bool = True,
+        dist_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(
+            observation_space,
+            action_space,
+            features_extractor=features_extractor,
+            normalize_images=normalize_images,
+            squash_output=True,
+        )
+
+        if dist_kwargs is None:
+            dist_kwargs = {}
+
+        # Save arguments to re-create object at loading
+        self.use_sde = use_sde
+        self.use_pca = use_pca
+        self.sde_features_extractor = None
+        self.net_arch = net_arch
+        self.features_dim = features_dim
+        self.activation_fn = activation_fn
+        self.log_std_init = log_std_init
+        self.use_expln = use_expln
+        self.full_std = full_std
+        self.clip_mean = clip_mean
+
+        action_dim = get_action_dim(self.action_space)
+        latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)
+        self.latent_pi = nn.Sequential(*latent_pi_net)
+        last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
+
+        assert not (self.use_sde and self.use_pca), "Cannot mix gSDE and PCA!"
+
+        if self.use_pca:
+            self.action_dist = PCA_Distribution(
+                action_dim, **dist_kwargs
+            )
+            self.mu, self.log_std = self.action_dist.proba_distribution_net(
+                latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init, **dist_kwargs
+            )
+            # Avoid numerical issues by limiting the mean of the Gaussian
+            # to be in [-clip_mean, clip_mean]
+            if clip_mean > 0.0:
+                self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
+        elif self.use_sde:
+            self.action_dist = StateDependentNoiseDistribution(
+                action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
+            )
+            self.mu, self.log_std = self.action_dist.proba_distribution_net(
+                latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init
+            )
+            # Avoid numerical issues by limiting the mean of the Gaussian
+            # to be in [-clip_mean, clip_mean]
+            if clip_mean > 0.0:
+                self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
+        else:
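+            # Default: squashed diagonal Gaussian (vanilla SAC exploration),
+            # used when neither gSDE nor PCA is enabled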
self.action_dist = SquashedDiagGaussianDistribution(action_dim) + self.mu = nn.Linear(last_layer_dim, action_dim) + self.log_std = nn.Linear(last_layer_dim, action_dim) + + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() + + data.update( + dict( + net_arch=self.net_arch, + features_dim=self.features_dim, + activation_fn=self.activation_fn, + use_sde=self.use_sde, + log_std_init=self.log_std_init, + full_std=self.full_std, + use_expln=self.use_expln, + features_extractor=self.features_extractor, + clip_mean=self.clip_mean, + ) + ) + return data + + def get_std(self) -> th.Tensor: + """ + Retrieve the standard deviation of the action distribution. + Only useful when using gSDE. + It corresponds to ``th.exp(log_std)`` in the normal case, + but is slightly different when using ``expln`` function + (cf StateDependentNoiseDistribution doc). + + :return: + """ + msg = "get_std() is only available when using gSDE" + assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg + return self.action_dist.get_std(self.log_std) + + def reset_noise(self, batch_size: int = 1) -> None: + """ + Sample new weights for the exploration matrix, when using gSDE. + + :param batch_size: + """ + msg = "reset_noise() is only available when using gSDE" + assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg + self.action_dist.sample_weights(self.log_std, batch_size=batch_size) + + def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]: + """ + Get the parameters for the action distribution. + + :param obs: + :return: + Mean, standard deviation and optional keyword arguments. + """ + features = self.extract_features(obs, self.features_extractor) + latent_pi = self.latent_pi(features) + mean_actions = self.mu(latent_pi) + + if self.use_sde: + return mean_actions, self.log_std, dict(latent_sde=latent_pi) + # Unstructured exploration (Original implementation) + log_std = self.log_std(latent_pi) + # Original Implementation to cap the standard deviation + log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) + return mean_actions, log_std, {} + + def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: + mean_actions, log_std, kwargs = self.get_action_dist_params(obs) + # Note: the action is squashed + return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs) + + def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + mean_actions, log_std, kwargs = self.get_action_dist_params(obs) + # return action and associated log prob + return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs) + + def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + return self(observation, deterministic) + + +class SACPolicy(BasePolicy): + """ + Policy class (with both actor and critic) for SAC. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + a positive standard deviation (cf paper). 
It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param n_critics: Number of critic networks to create. + :param share_features_extractor: Whether to share or not the features extractor + between the actor and the critic (this saves computation time) + """ + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + use_sde: bool = False, + log_std_init: float = -3, + use_pca: bool = False, + use_expln: bool = False, + clip_mean: float = 2.0, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2, + share_features_extractor: bool = False, + dist_kwargs={}, + ): + super().__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=True, + normalize_images=normalize_images, + ) + + if net_arch is None: + net_arch = [256, 256] + + actor_arch, critic_arch = get_actor_critic_arch(net_arch) + + self.net_arch = net_arch + self.activation_fn = activation_fn + self.net_args = { + "observation_space": self.observation_space, + "action_space": self.action_space, + "net_arch": actor_arch, + "activation_fn": self.activation_fn, + "normalize_images": normalize_images, + } + self.actor_kwargs = self.net_args.copy() + + sde_kwargs = { + "use_sde": use_sde, + "use_pca": use_pca, + "log_std_init": log_std_init, + "use_expln": use_expln, + "clip_mean": clip_mean, + "dist_kwargs": dist_kwargs, + } + + self.actor_kwargs.update(sde_kwargs) + self.critic_kwargs = self.net_args.copy() + self.critic_kwargs.update( + { + "n_critics": n_critics, + "net_arch": critic_arch, + "share_features_extractor": share_features_extractor, + } + ) + + self.actor, self.actor_target = None, None + self.critic, self.critic_target = None, None + self.share_features_extractor = share_features_extractor + + self._build(lr_schedule) + + def _build(self, lr_schedule: Schedule) -> None: + self.actor = self.make_actor() + self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + + if self.share_features_extractor: + self.critic = self.make_critic(features_extractor=self.actor.features_extractor) + # Do not optimize the shared features extractor with the critic loss + # otherwise, there are gradient computation issues + critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name] + else: + # Create a separate features extractor for the critic + # this requires more memory and computation + 
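+            # Passing features_extractor=None makes make_critic build a fresh
+            # features extractor for the critic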
self.critic = self.make_critic(features_extractor=None) + critic_parameters = self.critic.parameters() + + # Critic target should not share the features extractor with critic + self.critic_target = self.make_critic(features_extractor=None) + self.critic_target.load_state_dict(self.critic.state_dict()) + + self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs) + + # Target networks should always be in eval mode + self.critic_target.set_training_mode(False) + + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() + + data.update( + dict( + net_arch=self.net_arch, + activation_fn=self.net_args["activation_fn"], + use_sde=self.actor_kwargs["use_sde"], + log_std_init=self.actor_kwargs["log_std_init"], + use_expln=self.actor_kwargs["use_expln"], + clip_mean=self.actor_kwargs["clip_mean"], + n_critics=self.critic_kwargs["n_critics"], + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs, + ) + ) + return data + + def reset_noise(self, batch_size: int = 1) -> None: + """ + Sample new weights for the exploration matrix, when using gSDE. + + :param batch_size: + """ + self.actor.reset_noise(batch_size=batch_size) + + def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor: + actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor) + return Actor(**actor_kwargs).to(self.device) + + def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ContinuousCritic: + critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor) + return ContinuousCritic(**critic_kwargs).to(self.device) + + def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: + return self._predict(obs, deterministic=deterministic) + + def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + return self.actor(observation, deterministic) + + def set_training_mode(self, mode: bool) -> None: + """ + Put the policy in either training or evaluation mode. + + This affects certain modules, such as batch normalisation and dropout. 
+
+        :param mode: if true, set to training mode, else set to evaluation mode
+        """
+        self.actor.set_training_mode(mode)
+        self.critic.set_training_mode(mode)
+        self.training = mode
diff --git a/sbBrix/ppo/ppo.py b/sbBrix/ppo/ppo.py
index a91044a..0181975 100644
--- a/sbBrix/ppo/ppo.py
+++ b/sbBrix/ppo/ppo.py
@@ -6,15 +6,17 @@ import torch as th
 from gym import spaces
 from torch.nn import functional as F
 
-from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
-from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
+# from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
+from ..common.on_policy_algorithm import BetterOnPolicyAlgorithm
+# from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
+from ..common.policies import ActorCriticPolicy, BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import explained_variance, get_schedule_fn
 
 SelfPPO = TypeVar("SelfPPO", bound="PPO")
 
 
-class PPO(OnPolicyAlgorithm):
+class PPO(BetterOnPolicyAlgorithm):
     """
     Proximal Policy Optimization algorithm (PPO) (clip version)
 
@@ -69,9 +71,7 @@ class PPO(OnPolicyAlgorithm):
     """
 
     policy_aliases: Dict[str, Type[BasePolicy]] = {
-        "MlpPolicy": ActorCriticPolicy,
-        "CnnPolicy": ActorCriticCnnPolicy,
-        "MultiInputPolicy": MultiInputActorCriticPolicy,
+        "MlpPolicy": ActorCriticPolicy
     }
 
     def __init__(
@@ -92,6 +92,7 @@ class PPO(OnPolicyAlgorithm):
         max_grad_norm: float = 0.5,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
+        use_pca: bool = False,
         target_kl: Optional[float] = None,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
@@ -113,6 +114,7 @@ class PPO(OnPolicyAlgorithm):
             max_grad_norm=max_grad_norm,
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,
+            use_pca=use_pca,
             stats_window_size=stats_window_size,
             tensorboard_log=tensorboard_log,
             policy_kwargs=policy_kwargs,
@@ -315,4 +317,3 @@ class PPO(OnPolicyAlgorithm):
             reset_num_timesteps=reset_num_timesteps,
             progress_bar=progress_bar,
         )
-
diff --git a/sbBrix/sac/sac.py b/sbBrix/sac/sac.py
index 760c0eb..b3ee34e 100644
--- a/sbBrix/sac/sac.py
+++ b/sbBrix/sac/sac.py
@@ -7,7 +7,8 @@ from torch.nn import functional as F
 
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.noise import ActionNoise
-from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
+# from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
+from ..common.off_policy_algorithm import BetterOffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
@@ -16,7 +17,7 @@ from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolic
 SelfSAC = TypeVar("SelfSAC", bound="SAC")
 
 
-class SAC(OffPolicyAlgorithm):
+class SAC(BetterOffPolicyAlgorithm):
     """
     Soft Actor-Critic (SAC)
     Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor,
@@ -321,4 +322,3 @@ class SAC(OffPolicyAlgorithm):
         else:
             saved_pytorch_variables = ["ent_coef_tensor"]
         return state_dicts, saved_pytorch_variables
-
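
Usage sketch (assumes the ``sbBrix`` package and ``priorConditionedAnnealing`` are installed and that ``PPO`` is importable from ``sbBrix.ppo.ppo``; the environment name is illustrative only):

    import gym
    from sbBrix.ppo.ppo import PPO

    env = gym.make("Pendulum-v1")
    # use_pca enables PCA exploration (PCA_Distribution from priorConditionedAnnealing)
    # instead of the default diagonal Gaussian action distribution
    model = PPO("MlpPolicy", env, use_pca=True)
    model.learn(total_timesteps=10_000)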