# Metastable Baselines 2
<p align='center'>
<img src='./icon.svg'>
</p>
An extension to Stable Baselines 3, based on Metastable Baselines 1.

This repo provides:

- An implementation of ["Differentiable Trust Region Layers for Deep Reinforcement Learning" by Fabian Otto et al. (TRPL)](https://arxiv.org/abs/2101.09207)
- Support for Contextual Covariances
- Support for Full Covariances

## Installation

### Install dependency: Metastable Projections

Follow the instructions for [Metastable Projections](https://git.dominik-roth.eu/dodox/metastable-projections) ([GitHub Mirror](https://github.com/D-o-d-o-x/metastable-projections)).
KL projections additionally require ALR's ITPAL.

### Install as a package

Then install this repo as a package:

```bash
pip install -e .
```

To use full or contextual covariances, install with the optional dependency `pca`:

```bash
pip install -e '.[pca]'
```

## Usage

### TRPL

TRPL can be used just like SB3's PPO:

```python
import gymnasium as gym
from metastable_baselines2 import TRPL

env_id = 'LunarLanderContinuous-v2'
projection = 'Wasserstein'  # or 'Frobenius' or 'KL'

# Same call as SB3's PPO, plus the trust-region projection to use
model = TRPL(
    "MlpPolicy",
    env_id,
    n_steps=128,
    seed=0,
    policy_kwargs=dict(net_arch=[16]),
    projection_class=projection,
    verbose=1,
)

model.learn(total_timesteps=256)
```
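
The `projection` choice determines how the distance between the old and the new policy's covariance is measured. As a rough illustration, assuming diagonal Gaussians, the covariance parts of the three textbook distances could be computed as below. This is a minimal NumPy sketch, not code from Metastable Projections; TRPL may use scaled or otherwise modified variants, and `frobenius_dist`, `wasserstein_dist` and `kl_cov_dist` are hypothetical helper names.

```python
import numpy as np

# Covariance parts of three distances between diagonal Gaussians, given the
# variances of the new (var) and old (var_old) policy. Illustrative sketch only.

def frobenius_dist(var, var_old):
    """Squared Frobenius norm of the covariance difference."""
    return float(np.sum((var - var_old) ** 2))

def wasserstein_dist(var, var_old):
    """Covariance term of the squared 2-Wasserstein distance. For commuting
    (e.g. diagonal) covariances it compares matrix square roots, i.e. std devs."""
    return float(np.sum((np.sqrt(var) - np.sqrt(var_old)) ** 2))

def kl_cov_dist(var, var_old):
    """Covariance part of KL(N(0, var) || N(0, var_old))."""
    return float(0.5 * np.sum(var / var_old - 1.0 + np.log(var_old) - np.log(var)))

var_old = np.array([1.0, 1.0])
var = np.array([4.0, 1.0])  # the first dimension's variance grows by 4x
d_frob = frobenius_dist(var, var_old)
d_w2 = wasserstein_dist(var, var_old)
d_kl = kl_cov_dist(var, var_old)
```

For this example the three distances are 9.0, 1.0 and roughly 0.807, which shows how differently they weigh the same change in variance. All three are zero when `var == var_old`.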

Configure the projection by passing `projection_kwargs` to TRPL:

```python
# Reuses TRPL, env_id and projection from above; the bounds are illustrative, not recommended defaults
mean_bound, cov_bound = 0.03, 0.001
model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]),
             projection_class=projection, projection_kwargs={'mean_bound': mean_bound, 'cov_bound': cov_bound}, verbose=1)
```

For the available `projection_kwargs`, see [Metastable Projections](https://git.dominik-roth.eu/dodox/metastable-projections).

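
To give an intuition for what `mean_bound` and `cov_bound` limit, here is a minimal NumPy sketch of the closed-form mean projection and the Frobenius covariance projection, following our reading of the TRPL paper and simplified to a single state with diagonal covariances. It is not the Metastable Projections implementation; `project_mean` and `project_cov_frobenius` are hypothetical names.

```python
import numpy as np

def project_mean(mu, mu_old, var_old, mean_bound):
    """Pull mu towards mu_old so that the Mahalanobis distance (under the old
    diagonal covariance) is at most mean_bound."""
    dist = np.sum((mu - mu_old) ** 2 / var_old)
    if dist <= mean_bound:
        return mu  # already inside the trust region
    omega = np.sqrt(dist / mean_bound) - 1.0
    # The difference to mu_old shrinks by 1 / (1 + omega), so the distance becomes mean_bound
    return (mu + omega * mu_old) / (1.0 + omega)

def project_cov_frobenius(cov, cov_old, cov_bound):
    """Interpolate cov towards cov_old so that the squared Frobenius distance is at most cov_bound."""
    dist = np.sum((cov - cov_old) ** 2)
    if dist <= cov_bound:
        return cov
    eta = np.sqrt(dist / cov_bound) - 1.0
    return (cov + eta * cov_old) / (1.0 + eta)

mu_old, var_old = np.zeros(2), np.ones(2)
# Mahalanobis distance 25 is projected down to exactly 1.0: mu_proj == [0.6, 0.8]
mu_proj = project_mean(np.array([3.0, 4.0]), mu_old, var_old, mean_bound=1.0)
# Squared Frobenius distance 8 is projected down to exactly 2.0: cov_proj == 2 * I
cov_proj = project_cov_frobenius(3.0 * np.eye(2), np.eye(2), cov_bound=2.0)
```

Both projections interpolate between the new and the old distribution, so a policy update that already respects the bound is left untouched.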
### Full Covariance

SB3 does not support full covariances (only diagonal ones). We still support full covariances via the separate [PCA](https://git.dominik-roth.eu/dodox/PriorConditionedAnnealing) package. Since we don't actually want to use PCA (Prior Conditioned Annealing) itself, we pass `skip_conditioning=True`, so the underlying noise is used directly.

We therefore pass `use_pca=True` and, inside `policy_kwargs`, `dist_kwargs={'Base_Noise': 'WHITE', 'par_strength': 'FULL', 'skip_conditioning': True}`:

```python
# PPO and TRPL are supported (SAC is untested; PRs fixing issues are welcome).
# Reuses TRPL, env_id and projection from the TRPL example above.
model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, use_pca=True,
             policy_kwargs=dict(net_arch=[16], dist_kwargs={'par_strength': 'FULL', 'skip_conditioning': True}),
             projection_class=projection, verbose=1)

model.learn(total_timesteps=256)
```
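
Conceptually, with white base noise used directly, sampling from a full-covariance Gaussian reduces to the standard reparameterization: draw white noise ε ~ N(0, I) and map it through a lower-triangular Cholesky factor L of Σ, giving a = μ + Lε with covariance LLᵀ = Σ. The NumPy sketch below only illustrates that idea; it does not use the PCA package, whose internals may differ.

```python
import numpy as np

rng = np.random.default_rng(0)

mu = np.array([0.5, -1.0])
cov = np.array([[1.0, 0.6],
                [0.6, 2.0]])   # a full (non-diagonal) covariance matrix
L = np.linalg.cholesky(cov)    # lower-triangular factor with L @ L.T == cov

# White base noise: independent standard normals, one row per sample
eps = rng.standard_normal((100_000, 2))
actions = mu + eps @ L.T       # each row equals mu + L @ eps_i

emp_cov = np.cov(actions, rowvar=False)  # empirical covariance, close to cov
```

Because only L is needed, a policy can output (or learn) the Cholesky factor directly and never has to factorize Σ itself.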

The supported values for `par_strength` are:

- `SCALAR`: We learn a single scalar that is used along the whole diagonal. No covariance is modeled.
- `DIAG`: We learn a diagonal covariance matrix (i.e. only variances).
- `FULL`: We learn a full covariance matrix, parameterized via its Cholesky decomposition (except when the Wasserstein projection is used; then we use the Cholesky factor of the SPD matrix square root of the covariance matrix).
- `CONT_SCALAR`: Same as `SCALAR`, but the scalar is not global; it is parameterized by the policy net (contextual).
- `CONT_DIAG`: Same as `DIAG`, but the values are not global; they are parameterized by the policy net.
- `CONT_HYBRID`: We learn a parametric diagonal that is scaled by the policy net.
- `CONT_FULL`: Same as `FULL`, but parameterized by the policy net.
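
To make the non-contextual parameterizations concrete, the sketch below assembles a covariance matrix from learned parameters for `SCALAR`, `DIAG` and `FULL`; for the `CONT_*` variants the same quantities would come from the policy net per state instead of being global. The function `make_cov` and its arguments are hypothetical and this is not the package's code. For `FULL` with the Wasserstein projection, the learned lower-triangular factor is read as the Cholesky factor of Σ^(1/2), so Σ = (LLᵀ)².

```python
import numpy as np

def make_cov(par_strength, params, dim, wasserstein=False):
    """Assemble a covariance matrix from hypothetical learned parameters."""
    if par_strength == 'SCALAR':
        # One std shared by all action dimensions; no correlations
        return (params ** 2) * np.eye(dim)
    if par_strength == 'DIAG':
        # One std per action dimension (only variances)
        return np.diag(params ** 2)
    if par_strength == 'FULL':
        # params holds a lower-triangular factor with positive diagonal
        L = np.tril(params)
        S = L @ L.T
        # With Wasserstein, S is the SPD square root of the covariance
        return S @ S if wasserstein else S
    raise ValueError(f'unsupported par_strength: {par_strength}')

L = np.array([[1.0, 0.0],
              [0.5, 2.0]])
cov_full = make_cov('FULL', L, 2)                  # L @ L.T
cov_w2 = make_cov('FULL', L, 2, wasserstein=True)  # (L @ L.T) @ (L @ L.T)
```

Parameterizing through a triangular factor keeps the resulting matrix symmetric positive (semi-)definite by construction, whatever values the optimizer produces.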

## License

Since this repo is an extension of [Stable Baselines 3 by DLR-RM](https://github.com/DLR-RM/stable-baselines3), it contains some of its code. SB3 is licensed under the [MIT License](https://github.com/DLR-RM/stable-baselines3/blob/master/LICENSE), and so are our extensions.