Implement Importance Sampling for PCA
This commit is contained in:
parent
1fa66611a3
commit
0806a24036
38
metastable_baselines2/README.md
Normal file
38
metastable_baselines2/README.md
Normal file
@ -0,0 +1,38 @@
|
||||
# Metastable Baselines 2
|
||||
|
||||
<p align='center'>
|
||||
<img src='./icon.svg'>
|
||||
</p>
|
||||
|
||||
An extension to Stable Baselines 3. Based on Metastable Baselines 1.
|
||||
|
||||
During training of a RL-Agent we follow the gradient of the loss, which leads us to a minimum. In cases where the found minimum is merely a local minimum, this can be seen as a _false vacuum_ in our loss space. Exploration mechanisms try to let our training procedure escape these _stable states_: Making them _metastable_.
|
||||
|
||||
In order to archive this, this Repo contains some extensions for [Stable Baselines 3 by DLR-RM](https://github.com/DLR-RM/stable-baselines3)
|
||||
These extensions include:
|
||||
|
||||
- An implementation of ["Differentiable Trust Region Layers for Deep Reinforcement Learning" by Fabian Otto et al. (TRPL)](https://arxiv.org/abs/2101.09207)
|
||||
- Support for Prior Conditioned Annealing
|
||||
- Support for Contextual Covariances (Planned)
|
||||
- Support for Full Covariances (Planned)
|
||||
|
||||
The resulting algorithms can than be tested for their ability of exploration in the enviroments provided by [Fancy Gym](https://github.com/ALRhub/fancy_gym) or [Project Columbus](https://git.dominik-roth.eu/dodox/Columbus)
|
||||
|
||||
## Installation
|
||||
|
||||
#### Install dependency: Metastable Projections
|
||||
|
||||
Follow instructions for the [Metastable Projections](https://git.dominik-roth.eu/dodox/metastable-projections) ([GitHub Mirror](https://github.com/D-o-d-o-x/metastable-projections)).
|
||||
KL Projections require ALR's ITPAL as an additional dependecy.
|
||||
|
||||
#### Install as a package
|
||||
|
||||
Then install this repo as a package:
|
||||
|
||||
```
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
Since this Repo is an extension to [Stable Baselines 3 by DLR-RM](https://github.com/DLR-RM/stable-baselines3), it contains some of it's code. SB3 is licensed under the [MIT-License](https://github.com/DLR-RM/stable-baselines3/blob/master/LICENSE).
|
@ -95,6 +95,7 @@ class PPO(BetterOnPolicyAlgorithm):
|
||||
use_sde: bool = False,
|
||||
sde_sample_freq: int = -1,
|
||||
use_pca: bool = False,
|
||||
pca_is: bool = False,
|
||||
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
|
||||
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
target_kl: Optional[float] = None,
|
||||
@ -119,6 +120,7 @@ class PPO(BetterOnPolicyAlgorithm):
|
||||
use_sde=use_sde,
|
||||
sde_sample_freq=sde_sample_freq,
|
||||
use_pca=use_pca,
|
||||
pca_is=pca_is,
|
||||
rollout_buffer_class=rollout_buffer_class,
|
||||
rollout_buffer_kwargs=rollout_buffer_kwargs,
|
||||
stats_window_size=stats_window_size,
|
||||
@ -217,7 +219,7 @@ class PPO(BetterOnPolicyAlgorithm):
|
||||
if self.use_sde or self.use_pca:
|
||||
self.policy.reset_noise(self.batch_size)
|
||||
|
||||
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
|
||||
values, log_prob, entropy, _ = self.policy.evaluate_actions(rollout_data, actions)
|
||||
values = values.flatten()
|
||||
# Normalize advantage
|
||||
advantages = rollout_data.advantages
|
||||
@ -226,6 +228,7 @@ class PPO(BetterOnPolicyAlgorithm):
|
||||
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
||||
|
||||
# ratio between old and new policy, should be one at the first iteration
|
||||
# With pca_is=True, old_log_prob will be of the conditioned old dist (doing two Importance Sampling in one)
|
||||
ratio = th.exp(log_prob - rollout_data.old_log_prob)
|
||||
|
||||
# clipped surrogate loss
|
||||
|
@ -12,11 +12,16 @@ from ..common.policies import ActorCriticPolicy, BasePolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
|
||||
|
||||
#from metastable_baselines2 import PPO
|
||||
from metastable_projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
|
||||
|
||||
SelfTRPL = TypeVar("SelfTRPL", bound="TRPL")
|
||||
|
||||
|
||||
def castProjection(proj):
|
||||
if type(proj)==str:
|
||||
return getattr(metastable_projections, proj + 'ProjectionLayer')
|
||||
return proj
|
||||
|
||||
class TRPL(BetterOnPolicyAlgorithm):
|
||||
"""
|
||||
TODO: Bla
|
||||
@ -90,6 +95,8 @@ class TRPL(BetterOnPolicyAlgorithm):
|
||||
use_sde: bool = False,
|
||||
sde_sample_freq: int = -1,
|
||||
use_pca: bool = False,
|
||||
pca_is: bool = False,
|
||||
projection: Union[BaseProjectionLayer, str] = BaseProjectionLayer(),
|
||||
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
|
||||
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
target_kl: Optional[float] = None,
|
||||
@ -114,6 +121,7 @@ class TRPL(BetterOnPolicyAlgorithm):
|
||||
use_sde=use_sde,
|
||||
sde_sample_freq=sde_sample_freq,
|
||||
use_pca=use_pca,
|
||||
pca_is=pca_is,
|
||||
rollout_buffer_class=rollout_buffer_class,
|
||||
rollout_buffer_kwargs=rollout_buffer_kwargs,
|
||||
stats_window_size=stats_window_size,
|
||||
@ -163,6 +171,7 @@ class TRPL(BetterOnPolicyAlgorithm):
|
||||
self.clip_range = clip_range
|
||||
self.clip_range_vf = clip_range_vf
|
||||
self.normalize_advantage = normalize_advantage
|
||||
self.projection = castProjection(projection)
|
||||
self.target_kl = target_kl
|
||||
|
||||
if _init_setup_model:
|
||||
@ -193,8 +202,11 @@ class TRPL(BetterOnPolicyAlgorithm):
|
||||
if self.clip_range_vf is not None:
|
||||
clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
|
||||
|
||||
trust_region_losses = []
|
||||
entropy_losses = []
|
||||
pg_losses, value_losses = [], []
|
||||
pg_losses = []
|
||||
value_losses = []
|
||||
policy_losses = []
|
||||
clip_fractions = []
|
||||
|
||||
continue_training = True
|
||||
@ -212,7 +224,7 @@ class TRPL(BetterOnPolicyAlgorithm):
|
||||
if self.use_sde or self.use_pca:
|
||||
self.policy.reset_noise(self.batch_size)
|
||||
|
||||
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
|
||||
values, log_prob, entropy, trust_region_loss = self.policy.evaluate_actions(rollout_data, actions)
|
||||
values = values.flatten()
|
||||
# Normalize advantage
|
||||
advantages = rollout_data.advantages
|
||||
@ -221,15 +233,19 @@ class TRPL(BetterOnPolicyAlgorithm):
|
||||
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
||||
|
||||
# ratio between old and new policy, should be one at the first iteration
|
||||
# With pca_is=True, old_log_prob will be of the conditioned old dist (doing two Importance Sampling in one)
|
||||
ratio = th.exp(log_prob - rollout_data.old_log_prob)
|
||||
|
||||
# clipped surrogate loss
|
||||
policy_loss_1 = advantages * ratio
|
||||
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
|
||||
policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
|
||||
if self.clip_range is None:
|
||||
surrogate_loss = -(advantages * ratio).mean()
|
||||
else:
|
||||
surrogate_loss_1 = advantages * ratio
|
||||
surrogate_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
|
||||
surrogate_loss = -th.min(surrogate_loss_1, surrogate_loss_2).mean()
|
||||
|
||||
# Logging
|
||||
pg_losses.append(policy_loss.item())
|
||||
pg_losses.append(surrogate_loss.item())
|
||||
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
|
||||
clip_fractions.append(clip_fraction)
|
||||
|
||||
@ -253,10 +269,13 @@ class TRPL(BetterOnPolicyAlgorithm):
|
||||
else:
|
||||
entropy_loss = -th.mean(entropy)
|
||||
|
||||
entropy_losses.append(entropy_loss.item())
|
||||
|
||||
policy_loss = trust_region_loss + surrogate_loss
|
||||
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
|
||||
|
||||
entropy_losses.append(entropy_loss.item())
|
||||
trust_region_losses.append(trust_region_loss.item())
|
||||
policy_losses.append(policy_loss.item())
|
||||
|
||||
# Calculate approximate form of reverse KL Divergence for early stopping
|
||||
# see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
|
||||
# and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
|
||||
@ -288,6 +307,8 @@ class TRPL(BetterOnPolicyAlgorithm):
|
||||
# Logs
|
||||
self.logger.record("train/entropy_loss", np.mean(entropy_losses))
|
||||
self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
|
||||
self.logger.record("train/trust_region_loss", np.mean(trust_region_losses))
|
||||
self.logger.record("train/policy_loss", np.mean(policy_losses))
|
||||
self.logger.record("train/value_loss", np.mean(value_losses))
|
||||
self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
|
||||
self.logger.record("train/clip_fraction", np.mean(clip_fractions))
|
||||
@ -302,14 +323,14 @@ class TRPL(BetterOnPolicyAlgorithm):
|
||||
self.logger.record("train/clip_range_vf", clip_range_vf)
|
||||
|
||||
def learn(
|
||||
self: SelfPPO,
|
||||
self: SelfTRPL,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
tb_log_name: str = "PPO",
|
||||
tb_log_name: str = "TRPL",
|
||||
reset_num_timesteps: bool = True,
|
||||
progress_bar: bool = False,
|
||||
) -> SelfPPO:
|
||||
) -> SelfTRPL:
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
|
Loading…
Reference in New Issue
Block a user