# Metastable Baselines 2
<p align='center'>
<img src='./icon.svg'>
</p>
An extension to Stable Baselines 3. Based on Metastable Baselines 1.
This repo provides:
- An implementation of ["Differentiable Trust Region Layers for Deep Reinforcement Learning" by Fabian Otto et al. (TRPL)](https://arxiv.org/abs/2101.09207)
- Support for Contextual Covariances
- Support for Full Covariances
## Installation
#### Install dependency: Metastable Projections
Follow instructions for the [Metastable Projections](https://git.dominik-roth.eu/dodox/metastable-projections) ([GitHub Mirror](https://github.com/D-o-d-o-x/metastable-projections)).
KL projections additionally require ALR's ITPAL as a dependency.
#### Install as a package
Then install this repo as a package:
```bash
pip install -e .
```
To use full or contextual covariances, install with the optional dependency `pca`:
```bash
pip install -e '.[pca]'
```
## Usage
### TRPL
TRPL can be used just like SB3's PPO:
```python
import gymnasium as gym
from metastable_baselines2 import TRPL

env_id = 'LunarLanderContinuous-v2'
projection = 'Wasserstein'  # or 'Frobenius' or 'KL'

model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), projection_class=projection, verbose=1)

model.learn(total_timesteps=256)
```
Configure the projection by passing `projection_kwargs` to TRPL:
```python
# Placeholder bounds for illustration only; tune them for your task
mean_bound = 0.03
cov_bound = 0.001

model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), projection_class=projection, projection_kwargs={'mean_bound': mean_bound, 'cov_bound': cov_bound}, verbose=1)
```
For the available `projection_kwargs`, see [Metastable Projections](https://git.dominik-roth.eu/dodox/metastable-projections).
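For intuition about `mean_bound`, here is a minimal NumPy sketch of the mean part of the projection, following our reading of the TRPL paper (it is not code from this repo or from Metastable Projections, and the library may scale the distance differently). If the Mahalanobis distance of the new mean under the old covariance exceeds the bound, the mean is pulled back toward the old mean until the distance equals the bound. The function `project_mean` and all numbers are hypothetical.

```python
import numpy as np


def project_mean(mean, old_mean, old_cov, mean_bound):
    """Rescale the mean step so its Mahalanobis distance under old_cov is at most mean_bound."""
    diff = mean - old_mean
    # Mahalanobis distance of the new mean w.r.t. the old Gaussian
    maha = float(diff @ np.linalg.solve(old_cov, diff))
    if maha <= mean_bound:
        return mean  # already inside the trust region, nothing to do
    # Interpolation weight as in the TRPL paper: omega = sqrt(d / eps) - 1
    omega = np.sqrt(maha / mean_bound) - 1.0
    return (mean + omega * old_mean) / (1.0 + omega)


old_mean = np.zeros(2)
old_cov = np.diag([0.5, 2.0])
new_mean = np.array([1.0, 1.0])

projected = project_mean(new_mean, old_mean, old_cov, mean_bound=0.03)
print(projected)
```

Because the step is only rescaled, the projected mean stays on the line between the old and the new mean.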
### Full Covariance
SB3 does not support full covariances (only diagonal ones). We still provide support for full covariances via the separate PCA package. Since we don't actually want to use PCA (Prior Conditioned Annealing), we pass `skip_conditioning=True`, which causes the underlying noise to be used directly.
We therefore pass `use_pca=True` and `policy_kwargs.dist_kwargs = {'Base_Noise': 'WHITE', 'par_strength': 'FULL', 'skip_conditioning': True}`:
```python
# PPO and TRPL are supported (SAC is untested; PRs fixing issues are welcome)
model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, use_pca=True, policy_kwargs=dict(net_arch=[16], dist_kwargs={'par_strength': 'FULL', 'skip_conditioning': True}), projection_class=projection, verbose=1)

model.learn(total_timesteps=256)
```
The supported values for `par_strength` are:
- `SCALAR`: We learn a single scalar that is used along the whole diagonal. No covariances (off-diagonal terms) are modeled.
- `DIAG`: We learn a diagonal covariance matrix (i.e. only variances).
- `FULL`: We learn a full covariance matrix, induced via its Cholesky decomposition (except when the Wasserstein projection is used; then we use the Cholesky decomposition of the SPD matrix square root of the covariance matrix).
- `CONT_SCALAR`: Same as `SCALAR`, but the scalar is not global; it is parameterized by the policy net (contextual).
- `CONT_DIAG`: Same as `DIAG`, but the values are not global; they are parameterized by the policy net.
- `CONT_HYBRID`: We learn a parametric diagonal that is scaled by the policy net.
- `CONT_FULL`: Same as `FULL`, but parameterized by the policy net.
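The two `FULL` parameterizations can be illustrated with a small NumPy sketch (the helper names are hypothetical and not part of this package). Without Wasserstein, the learned factor is the Cholesky factor `L` of the covariance, so Σ = L Lᵀ. With Wasserstein, the learned factor `C` is the Cholesky factor of the SPD square root S = Σ^{1/2}, so S = C Cᵀ and Σ = S S.

```python
import numpy as np


def full_cov_from_cholesky(chol: np.ndarray) -> np.ndarray:
    """FULL (non-Wasserstein): Sigma = L @ L.T with L lower triangular."""
    return chol @ chol.T


def full_cov_from_sqrt_cholesky(chol_of_sqrt: np.ndarray) -> np.ndarray:
    """FULL with Wasserstein: S = C @ C.T is the SPD square root, Sigma = S @ S."""
    sqrt_cov = chol_of_sqrt @ chol_of_sqrt.T
    return sqrt_cov @ sqrt_cov


def spd_sqrt(cov: np.ndarray) -> np.ndarray:
    """Symmetric positive definite matrix square root via eigendecomposition."""
    eigvals, eigvecs = np.linalg.eigh(cov)
    # Scale each eigenvector column by sqrt(eigenvalue), then project back
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T


# Example covariance for a 2-D action space
cov = np.array([[2.0, 0.6], [0.6, 1.0]])

# Factors that would be learned under each parameterization
L = np.linalg.cholesky(cov)            # Cholesky of Sigma
C = np.linalg.cholesky(spd_sqrt(cov))  # Cholesky of sqrt(Sigma)

print(np.allclose(full_cov_from_cholesky(L), cov))       # True
print(np.allclose(full_cov_from_sqrt_cholesky(C), cov))  # True
```

Working with Σ^{1/2} suits the Wasserstein projection, since the 2-Wasserstein distance between Gaussians is expressed in terms of matrix square roots of the covariances.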
## License
Since this repo is an extension of [Stable Baselines 3 by DLR-RM](https://github.com/DLR-RM/stable-baselines3), it contains some of its code. SB3 is licensed under the [MIT License](https://github.com/DLR-RM/stable-baselines3/blob/master/LICENSE), and so are our extensions.