diff --git a/metastable_baselines2/README.md b/metastable_baselines2/README.md
new file mode 100644
index 0000000..a211610
--- /dev/null
+++ b/metastable_baselines2/README.md
@@ -0,0 +1,38 @@
+# Metastable Baselines 2
+
+
+
+
+
+An extension to Stable Baselines 3. Based on Metastable Baselines 1.
+
+During training of a RL-Agent we follow the gradient of the loss, which leads us to a minimum. In cases where the found minimum is merely a local minimum, this can be seen as a _false vacuum_ in our loss space. Exploration mechanisms try to let our training procedure escape these _stable states_: Making them _metastable_.
+
+In order to archive this, this Repo contains some extensions for [Stable Baselines 3 by DLR-RM](https://github.com/DLR-RM/stable-baselines3)
+These extensions include:
+
+- An implementation of ["Differentiable Trust Region Layers for Deep Reinforcement Learning" by Fabian Otto et al. (TRPL)](https://arxiv.org/abs/2101.09207)
+- Support for Prior Conditioned Annealing
+- Support for Contextual Covariances (Planned)
+- Support for Full Covariances (Planned)
+
+The resulting algorithms can than be tested for their ability of exploration in the enviroments provided by [Fancy Gym](https://github.com/ALRhub/fancy_gym) or [Project Columbus](https://git.dominik-roth.eu/dodox/Columbus)
+
+## Installation
+
+#### Install dependency: Metastable Projections
+
+Follow instructions for the [Metastable Projections](https://git.dominik-roth.eu/dodox/metastable-projections) ([GitHub Mirror](https://github.com/D-o-d-o-x/metastable-projections)).
+KL Projections require ALR's ITPAL as an additional dependecy.
+
+#### Install as a package
+
+Then install this repo as a package:
+
+```
+pip install -e .
+```
+
+## License
+
+Since this Repo is an extension to [Stable Baselines 3 by DLR-RM](https://github.com/DLR-RM/stable-baselines3), it contains some of it's code. SB3 is licensed under the [MIT-License](https://github.com/DLR-RM/stable-baselines3/blob/master/LICENSE).
diff --git a/metastable_baselines2/ppo/ppo.py b/metastable_baselines2/ppo/ppo.py
index 732f945..f9c59b5 100644
--- a/metastable_baselines2/ppo/ppo.py
+++ b/metastable_baselines2/ppo/ppo.py
@@ -95,6 +95,7 @@ class PPO(BetterOnPolicyAlgorithm):
use_sde: bool = False,
sde_sample_freq: int = -1,
use_pca: bool = False,
+ pca_is: bool = False,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
target_kl: Optional[float] = None,
@@ -119,6 +120,7 @@ class PPO(BetterOnPolicyAlgorithm):
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
use_pca=use_pca,
+ pca_is=pca_is,
rollout_buffer_class=rollout_buffer_class,
rollout_buffer_kwargs=rollout_buffer_kwargs,
stats_window_size=stats_window_size,
@@ -217,7 +219,7 @@ class PPO(BetterOnPolicyAlgorithm):
if self.use_sde or self.use_pca:
self.policy.reset_noise(self.batch_size)
- values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
+ values, log_prob, entropy, _ = self.policy.evaluate_actions(rollout_data, actions)
values = values.flatten()
# Normalize advantage
advantages = rollout_data.advantages
@@ -226,6 +228,7 @@ class PPO(BetterOnPolicyAlgorithm):
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# ratio between old and new policy, should be one at the first iteration
+ # With pca_is=True, old_log_prob will be of the conditioned old dist (doing two Importance Sampling in one)
ratio = th.exp(log_prob - rollout_data.old_log_prob)
# clipped surrogate loss
diff --git a/metastable_baselines2/trpl/trpl.py b/metastable_baselines2/trpl/trpl.py
index 60af830..7f91a87 100644
--- a/metastable_baselines2/trpl/trpl.py
+++ b/metastable_baselines2/trpl/trpl.py
@@ -12,11 +12,16 @@ from ..common.policies import ActorCriticPolicy, BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
-#from metastable_baselines2 import PPO
+from metastable_projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
SelfTRPL = TypeVar("SelfTRPL", bound="TRPL")
+def castProjection(proj):
+ if type(proj)==str:
+ return getattr(metastable_projections, proj + 'ProjectionLayer')
+ return proj
+
class TRPL(BetterOnPolicyAlgorithm):
"""
TODO: Bla
@@ -90,6 +95,8 @@ class TRPL(BetterOnPolicyAlgorithm):
use_sde: bool = False,
sde_sample_freq: int = -1,
use_pca: bool = False,
+ pca_is: bool = False,
+ projection: Union[BaseProjectionLayer, str] = BaseProjectionLayer(),
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
target_kl: Optional[float] = None,
@@ -114,6 +121,7 @@ class TRPL(BetterOnPolicyAlgorithm):
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
use_pca=use_pca,
+ pca_is=pca_is,
rollout_buffer_class=rollout_buffer_class,
rollout_buffer_kwargs=rollout_buffer_kwargs,
stats_window_size=stats_window_size,
@@ -163,6 +171,7 @@ class TRPL(BetterOnPolicyAlgorithm):
self.clip_range = clip_range
self.clip_range_vf = clip_range_vf
self.normalize_advantage = normalize_advantage
+ self.projection = castProjection(projection)
self.target_kl = target_kl
if _init_setup_model:
@@ -193,8 +202,11 @@ class TRPL(BetterOnPolicyAlgorithm):
if self.clip_range_vf is not None:
clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
+ trust_region_losses = []
entropy_losses = []
- pg_losses, value_losses = [], []
+ pg_losses = []
+ value_losses = []
+ policy_losses = []
clip_fractions = []
continue_training = True
@@ -212,7 +224,7 @@ class TRPL(BetterOnPolicyAlgorithm):
if self.use_sde or self.use_pca:
self.policy.reset_noise(self.batch_size)
- values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
+ values, log_prob, entropy, trust_region_loss = self.policy.evaluate_actions(rollout_data, actions)
values = values.flatten()
# Normalize advantage
advantages = rollout_data.advantages
@@ -221,15 +233,19 @@ class TRPL(BetterOnPolicyAlgorithm):
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# ratio between old and new policy, should be one at the first iteration
+ # With pca_is=True, old_log_prob will be of the conditioned old dist (doing two Importance Sampling in one)
ratio = th.exp(log_prob - rollout_data.old_log_prob)
# clipped surrogate loss
- policy_loss_1 = advantages * ratio
- policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
- policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
+ if self.clip_range is None:
+ surrogate_loss = -(advantages * ratio).mean()
+ else:
+ surrogate_loss_1 = advantages * ratio
+ surrogate_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
+ surrogate_loss = -th.min(surrogate_loss_1, surrogate_loss_2).mean()
# Logging
- pg_losses.append(policy_loss.item())
+ pg_losses.append(surrogate_loss.item())
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
clip_fractions.append(clip_fraction)
@@ -253,10 +269,13 @@ class TRPL(BetterOnPolicyAlgorithm):
else:
entropy_loss = -th.mean(entropy)
- entropy_losses.append(entropy_loss.item())
-
+ policy_loss = trust_region_loss + surrogate_loss
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
+ entropy_losses.append(entropy_loss.item())
+ trust_region_losses.append(trust_region_loss.item())
+ policy_losses.append(policy_loss.item())
+
# Calculate approximate form of reverse KL Divergence for early stopping
# see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
# and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
@@ -288,6 +307,8 @@ class TRPL(BetterOnPolicyAlgorithm):
# Logs
self.logger.record("train/entropy_loss", np.mean(entropy_losses))
self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
+ self.logger.record("train/trust_region_loss", np.mean(trust_region_losses))
+ self.logger.record("train/policy_loss", np.mean(policy_losses))
self.logger.record("train/value_loss", np.mean(value_losses))
self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
self.logger.record("train/clip_fraction", np.mean(clip_fractions))
@@ -302,14 +323,14 @@ class TRPL(BetterOnPolicyAlgorithm):
self.logger.record("train/clip_range_vf", clip_range_vf)
def learn(
- self: SelfPPO,
+ self: SelfTRPL,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
- tb_log_name: str = "PPO",
+ tb_log_name: str = "TRPL",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
- ) -> SelfPPO:
+ ) -> SelfTRPL:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,