Implement Importance Sampling for PCA

This commit is contained in:
Dominik Moritz Roth 2024-01-16 15:13:06 +01:00
parent 1fa66611a3
commit 0806a24036
3 changed files with 75 additions and 13 deletions

View File

@ -0,0 +1,38 @@
# Metastable Baselines 2
<p align='center'>
<img src='./icon.svg'>
</p>
An extension to Stable Baselines 3. Based on Metastable Baselines 1.
During training of a RL-Agent we follow the gradient of the loss, which leads us to a minimum. In cases where the found minimum is merely a local minimum, this can be seen as a _false vacuum_ in our loss space. Exploration mechanisms try to let our training procedure escape these _stable states_: Making them _metastable_.
In order to archive this, this Repo contains some extensions for [Stable Baselines 3 by DLR-RM](https://github.com/DLR-RM/stable-baselines3)
These extensions include:
- An implementation of ["Differentiable Trust Region Layers for Deep Reinforcement Learning" by Fabian Otto et al. (TRPL)](https://arxiv.org/abs/2101.09207)
- Support for Prior Conditioned Annealing
- Support for Contextual Covariances (Planned)
- Support for Full Covariances (Planned)
The resulting algorithms can than be tested for their ability of exploration in the enviroments provided by [Fancy Gym](https://github.com/ALRhub/fancy_gym) or [Project Columbus](https://git.dominik-roth.eu/dodox/Columbus)
## Installation
#### Install dependency: Metastable Projections
Follow instructions for the [Metastable Projections](https://git.dominik-roth.eu/dodox/metastable-projections) ([GitHub Mirror](https://github.com/D-o-d-o-x/metastable-projections)).
KL Projections require ALR's ITPAL as an additional dependecy.
#### Install as a package
Then install this repo as a package:
```
pip install -e .
```
## License
Since this Repo is an extension to [Stable Baselines 3 by DLR-RM](https://github.com/DLR-RM/stable-baselines3), it contains some of it's code. SB3 is licensed under the [MIT-License](https://github.com/DLR-RM/stable-baselines3/blob/master/LICENSE).

View File

@ -95,6 +95,7 @@ class PPO(BetterOnPolicyAlgorithm):
use_sde: bool = False,
sde_sample_freq: int = -1,
use_pca: bool = False,
pca_is: bool = False,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
target_kl: Optional[float] = None,
@ -119,6 +120,7 @@ class PPO(BetterOnPolicyAlgorithm):
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
use_pca=use_pca,
pca_is=pca_is,
rollout_buffer_class=rollout_buffer_class,
rollout_buffer_kwargs=rollout_buffer_kwargs,
stats_window_size=stats_window_size,
@ -217,7 +219,7 @@ class PPO(BetterOnPolicyAlgorithm):
if self.use_sde or self.use_pca:
self.policy.reset_noise(self.batch_size)
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
values, log_prob, entropy, _ = self.policy.evaluate_actions(rollout_data, actions)
values = values.flatten()
# Normalize advantage
advantages = rollout_data.advantages
@ -226,6 +228,7 @@ class PPO(BetterOnPolicyAlgorithm):
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# ratio between old and new policy, should be one at the first iteration
# With pca_is=True, old_log_prob will be of the conditioned old dist (doing two Importance Sampling in one)
ratio = th.exp(log_prob - rollout_data.old_log_prob)
# clipped surrogate loss

View File

@ -12,11 +12,16 @@ from ..common.policies import ActorCriticPolicy, BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
#from metastable_baselines2 import PPO
from metastable_projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
SelfTRPL = TypeVar("SelfTRPL", bound="TRPL")
def castProjection(proj):
if type(proj)==str:
return getattr(metastable_projections, proj + 'ProjectionLayer')
return proj
class TRPL(BetterOnPolicyAlgorithm):
"""
TODO: Bla
@ -90,6 +95,8 @@ class TRPL(BetterOnPolicyAlgorithm):
use_sde: bool = False,
sde_sample_freq: int = -1,
use_pca: bool = False,
pca_is: bool = False,
projection: Union[BaseProjectionLayer, str] = BaseProjectionLayer(),
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
target_kl: Optional[float] = None,
@ -114,6 +121,7 @@ class TRPL(BetterOnPolicyAlgorithm):
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
use_pca=use_pca,
pca_is=pca_is,
rollout_buffer_class=rollout_buffer_class,
rollout_buffer_kwargs=rollout_buffer_kwargs,
stats_window_size=stats_window_size,
@ -163,6 +171,7 @@ class TRPL(BetterOnPolicyAlgorithm):
self.clip_range = clip_range
self.clip_range_vf = clip_range_vf
self.normalize_advantage = normalize_advantage
self.projection = castProjection(projection)
self.target_kl = target_kl
if _init_setup_model:
@ -193,8 +202,11 @@ class TRPL(BetterOnPolicyAlgorithm):
if self.clip_range_vf is not None:
clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
trust_region_losses = []
entropy_losses = []
pg_losses, value_losses = [], []
pg_losses = []
value_losses = []
policy_losses = []
clip_fractions = []
continue_training = True
@ -212,7 +224,7 @@ class TRPL(BetterOnPolicyAlgorithm):
if self.use_sde or self.use_pca:
self.policy.reset_noise(self.batch_size)
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
values, log_prob, entropy, trust_region_loss = self.policy.evaluate_actions(rollout_data, actions)
values = values.flatten()
# Normalize advantage
advantages = rollout_data.advantages
@ -221,15 +233,19 @@ class TRPL(BetterOnPolicyAlgorithm):
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# ratio between old and new policy, should be one at the first iteration
# With pca_is=True, old_log_prob will be of the conditioned old dist (doing two Importance Sampling in one)
ratio = th.exp(log_prob - rollout_data.old_log_prob)
# clipped surrogate loss
policy_loss_1 = advantages * ratio
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
if self.clip_range is None:
surrogate_loss = -(advantages * ratio).mean()
else:
surrogate_loss_1 = advantages * ratio
surrogate_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
surrogate_loss = -th.min(surrogate_loss_1, surrogate_loss_2).mean()
# Logging
pg_losses.append(policy_loss.item())
pg_losses.append(surrogate_loss.item())
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
clip_fractions.append(clip_fraction)
@ -253,10 +269,13 @@ class TRPL(BetterOnPolicyAlgorithm):
else:
entropy_loss = -th.mean(entropy)
entropy_losses.append(entropy_loss.item())
policy_loss = trust_region_loss + surrogate_loss
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
entropy_losses.append(entropy_loss.item())
trust_region_losses.append(trust_region_loss.item())
policy_losses.append(policy_loss.item())
# Calculate approximate form of reverse KL Divergence for early stopping
# see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
# and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
@ -288,6 +307,8 @@ class TRPL(BetterOnPolicyAlgorithm):
# Logs
self.logger.record("train/entropy_loss", np.mean(entropy_losses))
self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
self.logger.record("train/trust_region_loss", np.mean(trust_region_losses))
self.logger.record("train/policy_loss", np.mean(policy_losses))
self.logger.record("train/value_loss", np.mean(value_losses))
self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
self.logger.record("train/clip_fraction", np.mean(clip_fractions))
@ -302,14 +323,14 @@ class TRPL(BetterOnPolicyAlgorithm):
self.logger.record("train/clip_range_vf", clip_range_vf)
def learn(
self: SelfPPO,
self: SelfTRPL,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
tb_log_name: str = "PPO",
tb_log_name: str = "TRPL",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfPPO:
) -> SelfTRPL:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,