Implement Importance Sampling for PCA
parent 1fa66611a3
commit 0806a24036

metastable_baselines2/README.md (new file, 38 lines added)
@@ -0,0 +1,38 @@
# Metastable Baselines 2

<p align='center'>
  <img src='./icon.svg'>
</p>

An extension to Stable Baselines 3, based on Metastable Baselines 1.

During training of an RL agent we follow the gradient of the loss, which leads us to a minimum. In cases where the found minimum is merely a local minimum, this can be seen as a _false vacuum_ in our loss space. Exploration mechanisms try to let our training procedure escape these _stable states_, making them _metastable_.

In order to achieve this, this repo contains some extensions for [Stable Baselines 3 by DLR-RM](https://github.com/DLR-RM/stable-baselines3).
These extensions include:

- An implementation of ["Differentiable Trust Region Layers for Deep Reinforcement Learning" by Fabian Otto et al. (TRPL)](https://arxiv.org/abs/2101.09207)
- Support for Prior Conditioned Annealing
- Support for Contextual Covariances (Planned)
- Support for Full Covariances (Planned)

The resulting algorithms can then be tested for their exploration capabilities in the environments provided by [Fancy Gym](https://github.com/ALRhub/fancy_gym) or [Project Columbus](https://git.dominik-roth.eu/dodox/Columbus).

## Installation

#### Install dependency: Metastable Projections

Follow the instructions for [Metastable Projections](https://git.dominik-roth.eu/dodox/metastable-projections) ([GitHub Mirror](https://github.com/D-o-d-o-x/metastable-projections)).
KL projections require ALR's ITPAL as an additional dependency.

#### Install as a package

Then install this repo as a package:

```
pip install -e .
```

## License

Since this repo is an extension to [Stable Baselines 3 by DLR-RM](https://github.com/DLR-RM/stable-baselines3), it contains some of its code. SB3 is licensed under the [MIT License](https://github.com/DLR-RM/stable-baselines3/blob/master/LICENSE).
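The changes below wire the new `pca_is` flag through `PPO` and `TRPL`. A minimal usage sketch: the `projection`, `use_pca`, and `pca_is` arguments come from this commit, while the import path, policy string, and environment id are assumptions and may differ.

```python
import fancy_gym  # noqa: F401  # assumed: importing registers the Fancy Gym environments
from metastable_baselines2 import TRPL  # assumed import path

model = TRPL(
    "MlpPolicy",            # assumed policy name
    "fancy/Reacher5d-v0",   # assumed environment id
    projection="KL",        # string form, resolved via the castProjection helper added below
    use_pca=True,           # Prior Conditioned Annealing
    pca_is=True,            # fold the PCA correction into the importance-sampling ratio
)
model.learn(total_timesteps=100_000)
```

With the default `pca_is=False` the behaviour should be unchanged; the flag only affects which old log-probability enters the ratio in the hunks below.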
@@ -95,6 +95,7 @@ class PPO(BetterOnPolicyAlgorithm):
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         use_pca: bool = False,
+        pca_is: bool = False,
         rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
         rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
         target_kl: Optional[float] = None,
@@ -119,6 +120,7 @@ class PPO(BetterOnPolicyAlgorithm):
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,
             use_pca=use_pca,
+            pca_is=pca_is,
             rollout_buffer_class=rollout_buffer_class,
             rollout_buffer_kwargs=rollout_buffer_kwargs,
             stats_window_size=stats_window_size,
@@ -217,7 +219,7 @@ class PPO(BetterOnPolicyAlgorithm):
                 if self.use_sde or self.use_pca:
                     self.policy.reset_noise(self.batch_size)

-                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
+                values, log_prob, entropy, _ = self.policy.evaluate_actions(rollout_data, actions)
                 values = values.flatten()
                 # Normalize advantage
                 advantages = rollout_data.advantages
@@ -226,6 +228,7 @@ class PPO(BetterOnPolicyAlgorithm):
                     advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                 # ratio between old and new policy, should be one at the first iteration
+                # With pca_is=True, old_log_prob is that of the conditioned old dist (two importance-sampling corrections folded into one)
                 ratio = th.exp(log_prob - rollout_data.old_log_prob)

                 # clipped surrogate loss
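The added comment is the core of this commit: with `pca_is=True`, the rollout buffer stores the log-probability under the PCA-conditioned old distribution, so a single ratio corrects for both the policy update and the conditioning. A toy illustration with made-up numbers:

```python
import torch as th

# log-probabilities of the same sampled action under three distributions (made-up values)
log_pi_new = th.tensor(-1.20)  # current policy pi_new(a|s)
log_pi_old = th.tensor(-1.00)  # unconditioned old policy pi_old(a|s)
log_q_old = th.tensor(-0.80)   # PCA-conditioned old distribution q_old(a|s) that sampled a

# two separate importance-sampling corrections ...
two_step = th.exp(log_pi_new - log_pi_old) * th.exp(log_pi_old - log_q_old)
# ... collapse into the single ratio the training loop computes when old_log_prob = log_q_old
one_step = th.exp(log_pi_new - log_q_old)
assert th.isclose(two_step, one_step)
```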
@@ -12,11 +12,16 @@ from ..common.policies import ActorCriticPolicy, BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import explained_variance, get_schedule_fn

-#from metastable_baselines2 import PPO
+from metastable_projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer

 SelfTRPL = TypeVar("SelfTRPL", bound="TRPL")


+def castProjection(proj):
+    if type(proj)==str:
+        return getattr(metastable_projections, proj + 'ProjectionLayer')
+    return proj
+
 class TRPL(BetterOnPolicyAlgorithm):
     """
     TODO: Bla
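As committed, `castProjection` looks the class up on the `metastable_projections` module object, which requires a plain `import metastable_projections` that is not visible in this hunk (only the individual layer classes are imported above). A self-contained sketch of the helper under that assumption, using `isinstance` instead of `type(...) ==`:

```python
import metastable_projections  # assumed to be importable as a module
from metastable_projections import KLProjectionLayer


def cast_projection(proj):
    """Resolve a projection given as a short string ('Base', 'Frobenius',
    'Wasserstein', 'KL') to its layer class; pass anything else through."""
    if isinstance(proj, str):
        return getattr(metastable_projections, proj + 'ProjectionLayer')
    return proj


assert cast_projection('KL') is KLProjectionLayer
assert cast_projection(KLProjectionLayer) is KLProjectionLayer
```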
@@ -90,6 +95,8 @@ class TRPL(BetterOnPolicyAlgorithm):
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         use_pca: bool = False,
+        pca_is: bool = False,
+        projection: Union[BaseProjectionLayer, str] = BaseProjectionLayer(),
         rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
         rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
         target_kl: Optional[float] = None,
@@ -114,6 +121,7 @@ class TRPL(BetterOnPolicyAlgorithm):
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,
             use_pca=use_pca,
+            pca_is=pca_is,
             rollout_buffer_class=rollout_buffer_class,
             rollout_buffer_kwargs=rollout_buffer_kwargs,
             stats_window_size=stats_window_size,
@@ -163,6 +171,7 @@ class TRPL(BetterOnPolicyAlgorithm):
         self.clip_range = clip_range
         self.clip_range_vf = clip_range_vf
         self.normalize_advantage = normalize_advantage
+        self.projection = castProjection(projection)
         self.target_kl = target_kl

         if _init_setup_model:
@@ -193,8 +202,11 @@ class TRPL(BetterOnPolicyAlgorithm):
         if self.clip_range_vf is not None:
             clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

+        trust_region_losses = []
         entropy_losses = []
-        pg_losses, value_losses = [], []
+        pg_losses = []
+        value_losses = []
+        policy_losses = []
         clip_fractions = []

         continue_training = True
@@ -212,7 +224,7 @@ class TRPL(BetterOnPolicyAlgorithm):
                 if self.use_sde or self.use_pca:
                     self.policy.reset_noise(self.batch_size)

-                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
+                values, log_prob, entropy, trust_region_loss = self.policy.evaluate_actions(rollout_data, actions)
                 values = values.flatten()
                 # Normalize advantage
                 advantages = rollout_data.advantages
@@ -221,15 +233,19 @@
                     advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                 # ratio between old and new policy, should be one at the first iteration
+                # With pca_is=True, old_log_prob is that of the conditioned old dist (two importance-sampling corrections folded into one)
                 ratio = th.exp(log_prob - rollout_data.old_log_prob)

                 # clipped surrogate loss
-                policy_loss_1 = advantages * ratio
-                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
-                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
+                if self.clip_range is None:
+                    surrogate_loss = -(advantages * ratio).mean()
+                else:
+                    surrogate_loss_1 = advantages * ratio
+                    surrogate_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
+                    surrogate_loss = -th.min(surrogate_loss_1, surrogate_loss_2).mean()

                 # Logging
-                pg_losses.append(policy_loss.item())
+                pg_losses.append(surrogate_loss.item())
                 clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                 clip_fractions.append(clip_fraction)

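A toy contrast of the two branches introduced above, with made-up values: `clip_range=None` gives the plain surrogate, otherwise the PPO-style clipped objective is kept.

```python
import torch as th

advantages = th.tensor([1.0, -1.0])   # made-up advantages
ratio = th.tensor([1.3, 0.7])         # made-up importance-sampling ratios
clip_range = 0.2

plain_surrogate = -(advantages * ratio).mean()
clipped_surrogate = -th.min(
    advantages * ratio,
    advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range),
).mean()
print(plain_surrogate.item(), clipped_surrogate.item())  # approx. -0.3 and -0.2
```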
@@ -253,10 +269,13 @@
                 else:
                     entropy_loss = -th.mean(entropy)

-                entropy_losses.append(entropy_loss.item())
+                policy_loss = trust_region_loss + surrogate_loss

                 loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

+                entropy_losses.append(entropy_loss.item())
+                trust_region_losses.append(trust_region_loss.item())
+                policy_losses.append(policy_loss.item())

                 # Calculate approximate form of reverse KL Divergence for early stopping
                 # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                 # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
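A toy recomputation of how the total loss is now assembled, with the trust-region penalty from `evaluate_actions` folded into the policy loss. Numbers are made up; 0.0 and 0.5 are SB3's default `ent_coef` and `vf_coef`:

```python
import torch as th

surrogate_loss = th.tensor(0.05)     # made-up clipped/plain surrogate value
trust_region_loss = th.tensor(0.01)  # made-up projection penalty (4th return of evaluate_actions)
entropy_loss = th.tensor(-1.30)      # made-up entropy term
value_loss = th.tensor(0.40)         # made-up value-function loss
ent_coef, vf_coef = 0.0, 0.5         # SB3 defaults, assumed here

policy_loss = trust_region_loss + surrogate_loss
loss = policy_loss + ent_coef * entropy_loss + vf_coef * value_loss
print(loss.item())  # approx. 0.26
```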
@@ -288,6 +307,8 @@
             # Logs
             self.logger.record("train/entropy_loss", np.mean(entropy_losses))
             self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
+            self.logger.record("train/trust_region_loss", np.mean(trust_region_losses))
+            self.logger.record("train/policy_loss", np.mean(policy_losses))
             self.logger.record("train/value_loss", np.mean(value_losses))
             self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
             self.logger.record("train/clip_fraction", np.mean(clip_fractions))
@@ -302,14 +323,14 @@
             self.logger.record("train/clip_range_vf", clip_range_vf)

     def learn(
-        self: SelfPPO,
+        self: SelfTRPL,
         total_timesteps: int,
         callback: MaybeCallback = None,
         log_interval: int = 1,
-        tb_log_name: str = "PPO",
+        tb_log_name: str = "TRPL",
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
-    ) -> SelfPPO:
+    ) -> SelfTRPL:
         return super().learn(
             total_timesteps=total_timesteps,
             callback=callback,