From 0806a24036917419f2b50ebeee4fe2e62308a74f Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 16 Jan 2024 15:13:06 +0100 Subject: [PATCH] Implement Importance Sampling for PCA --- metastable_baselines2/README.md | 38 +++++++++++++++++++++++++ metastable_baselines2/ppo/ppo.py | 5 +++- metastable_baselines2/trpl/trpl.py | 45 ++++++++++++++++++++++-------- 3 files changed, 75 insertions(+), 13 deletions(-) create mode 100644 metastable_baselines2/README.md diff --git a/metastable_baselines2/README.md b/metastable_baselines2/README.md new file mode 100644 index 0000000..a211610 --- /dev/null +++ b/metastable_baselines2/README.md @@ -0,0 +1,38 @@ +# Metastable Baselines 2 + +

+ +

+ +An extension to Stable Baselines 3. Based on Metastable Baselines 1. + +During training of a RL-Agent we follow the gradient of the loss, which leads us to a minimum. In cases where the found minimum is merely a local minimum, this can be seen as a _false vacuum_ in our loss space. Exploration mechanisms try to let our training procedure escape these _stable states_: Making them _metastable_. + +In order to archive this, this Repo contains some extensions for [Stable Baselines 3 by DLR-RM](https://github.com/DLR-RM/stable-baselines3) +These extensions include: + +- An implementation of ["Differentiable Trust Region Layers for Deep Reinforcement Learning" by Fabian Otto et al. (TRPL)](https://arxiv.org/abs/2101.09207) +- Support for Prior Conditioned Annealing +- Support for Contextual Covariances (Planned) +- Support for Full Covariances (Planned) + +The resulting algorithms can than be tested for their ability of exploration in the enviroments provided by [Fancy Gym](https://github.com/ALRhub/fancy_gym) or [Project Columbus](https://git.dominik-roth.eu/dodox/Columbus) + +## Installation + +#### Install dependency: Metastable Projections + +Follow instructions for the [Metastable Projections](https://git.dominik-roth.eu/dodox/metastable-projections) ([GitHub Mirror](https://github.com/D-o-d-o-x/metastable-projections)). +KL Projections require ALR's ITPAL as an additional dependecy. + +#### Install as a package + +Then install this repo as a package: + +``` +pip install -e . +``` + +## License + +Since this Repo is an extension to [Stable Baselines 3 by DLR-RM](https://github.com/DLR-RM/stable-baselines3), it contains some of it's code. SB3 is licensed under the [MIT-License](https://github.com/DLR-RM/stable-baselines3/blob/master/LICENSE). diff --git a/metastable_baselines2/ppo/ppo.py b/metastable_baselines2/ppo/ppo.py index 732f945..f9c59b5 100644 --- a/metastable_baselines2/ppo/ppo.py +++ b/metastable_baselines2/ppo/ppo.py @@ -95,6 +95,7 @@ class PPO(BetterOnPolicyAlgorithm): use_sde: bool = False, sde_sample_freq: int = -1, use_pca: bool = False, + pca_is: bool = False, rollout_buffer_class: Optional[Type[RolloutBuffer]] = None, rollout_buffer_kwargs: Optional[Dict[str, Any]] = None, target_kl: Optional[float] = None, @@ -119,6 +120,7 @@ class PPO(BetterOnPolicyAlgorithm): use_sde=use_sde, sde_sample_freq=sde_sample_freq, use_pca=use_pca, + pca_is=pca_is, rollout_buffer_class=rollout_buffer_class, rollout_buffer_kwargs=rollout_buffer_kwargs, stats_window_size=stats_window_size, @@ -217,7 +219,7 @@ class PPO(BetterOnPolicyAlgorithm): if self.use_sde or self.use_pca: self.policy.reset_noise(self.batch_size) - values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions) + values, log_prob, entropy, _ = self.policy.evaluate_actions(rollout_data, actions) values = values.flatten() # Normalize advantage advantages = rollout_data.advantages @@ -226,6 +228,7 @@ class PPO(BetterOnPolicyAlgorithm): advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # ratio between old and new policy, should be one at the first iteration + # With pca_is=True, old_log_prob will be of the conditioned old dist (doing two Importance Sampling in one) ratio = th.exp(log_prob - rollout_data.old_log_prob) # clipped surrogate loss diff --git a/metastable_baselines2/trpl/trpl.py b/metastable_baselines2/trpl/trpl.py index 60af830..7f91a87 100644 --- a/metastable_baselines2/trpl/trpl.py +++ b/metastable_baselines2/trpl/trpl.py @@ -12,11 +12,16 @@ from ..common.policies import ActorCriticPolicy, BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance, get_schedule_fn -#from metastable_baselines2 import PPO +from metastable_projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer SelfTRPL = TypeVar("SelfTRPL", bound="TRPL") +def castProjection(proj): + if type(proj)==str: + return getattr(metastable_projections, proj + 'ProjectionLayer') + return proj + class TRPL(BetterOnPolicyAlgorithm): """ TODO: Bla @@ -90,6 +95,8 @@ class TRPL(BetterOnPolicyAlgorithm): use_sde: bool = False, sde_sample_freq: int = -1, use_pca: bool = False, + pca_is: bool = False, + projection: Union[BaseProjectionLayer, str] = BaseProjectionLayer(), rollout_buffer_class: Optional[Type[RolloutBuffer]] = None, rollout_buffer_kwargs: Optional[Dict[str, Any]] = None, target_kl: Optional[float] = None, @@ -114,6 +121,7 @@ class TRPL(BetterOnPolicyAlgorithm): use_sde=use_sde, sde_sample_freq=sde_sample_freq, use_pca=use_pca, + pca_is=pca_is, rollout_buffer_class=rollout_buffer_class, rollout_buffer_kwargs=rollout_buffer_kwargs, stats_window_size=stats_window_size, @@ -163,6 +171,7 @@ class TRPL(BetterOnPolicyAlgorithm): self.clip_range = clip_range self.clip_range_vf = clip_range_vf self.normalize_advantage = normalize_advantage + self.projection = castProjection(projection) self.target_kl = target_kl if _init_setup_model: @@ -193,8 +202,11 @@ class TRPL(BetterOnPolicyAlgorithm): if self.clip_range_vf is not None: clip_range_vf = self.clip_range_vf(self._current_progress_remaining) + trust_region_losses = [] entropy_losses = [] - pg_losses, value_losses = [], [] + pg_losses = [] + value_losses = [] + policy_losses = [] clip_fractions = [] continue_training = True @@ -212,7 +224,7 @@ class TRPL(BetterOnPolicyAlgorithm): if self.use_sde or self.use_pca: self.policy.reset_noise(self.batch_size) - values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions) + values, log_prob, entropy, trust_region_loss = self.policy.evaluate_actions(rollout_data, actions) values = values.flatten() # Normalize advantage advantages = rollout_data.advantages @@ -221,15 +233,19 @@ class TRPL(BetterOnPolicyAlgorithm): advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # ratio between old and new policy, should be one at the first iteration + # With pca_is=True, old_log_prob will be of the conditioned old dist (doing two Importance Sampling in one) ratio = th.exp(log_prob - rollout_data.old_log_prob) # clipped surrogate loss - policy_loss_1 = advantages * ratio - policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) - policy_loss = -th.min(policy_loss_1, policy_loss_2).mean() + if self.clip_range is None: + surrogate_loss = -(advantages * ratio).mean() + else: + surrogate_loss_1 = advantages * ratio + surrogate_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) + surrogate_loss = -th.min(surrogate_loss_1, surrogate_loss_2).mean() # Logging - pg_losses.append(policy_loss.item()) + pg_losses.append(surrogate_loss.item()) clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() clip_fractions.append(clip_fraction) @@ -253,10 +269,13 @@ class TRPL(BetterOnPolicyAlgorithm): else: entropy_loss = -th.mean(entropy) - entropy_losses.append(entropy_loss.item()) - + policy_loss = trust_region_loss + surrogate_loss loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + entropy_losses.append(entropy_loss.item()) + trust_region_losses.append(trust_region_loss.item()) + policy_losses.append(policy_loss.item()) + # Calculate approximate form of reverse KL Divergence for early stopping # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 @@ -288,6 +307,8 @@ class TRPL(BetterOnPolicyAlgorithm): # Logs self.logger.record("train/entropy_loss", np.mean(entropy_losses)) self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) + self.logger.record("train/trust_region_loss", np.mean(trust_region_losses)) + self.logger.record("train/policy_loss", np.mean(policy_losses)) self.logger.record("train/value_loss", np.mean(value_losses)) self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) self.logger.record("train/clip_fraction", np.mean(clip_fractions)) @@ -302,14 +323,14 @@ class TRPL(BetterOnPolicyAlgorithm): self.logger.record("train/clip_range_vf", clip_range_vf) def learn( - self: SelfPPO, + self: SelfTRPL, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, - tb_log_name: str = "PPO", + tb_log_name: str = "TRPL", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> SelfPPO: + ) -> SelfTRPL: return super().learn( total_timesteps=total_timesteps, callback=callback,