
Metastable Baselines 2

An extension to Stable Baselines 3. Based on Metastable Baselines 1.

This repo provides:

Installation

Install dependency: Metastable Projections

Follow the installation instructions for Metastable Projections (GitHub Mirror). KL projections require ALR's ITPAL as an additional dependency.

Install as a package

Then install this repo as a package:

pip install -e .

If you want to be able to use full / contextual covariances, install with the optional dependency 'pca':

pip install -e '.[pca]'

Usage

TRPL

TRPL can be used just like SB3's PPO:

import gymnasium as gym
from metastable_baselines2 import TRPL

env_id = 'LunarLanderContinuous-v2'
projection = 'Wasserstein' # or Frobenius or KL

model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), projection_class=projection, verbose=1)

model.learn(total_timesteps=256)

Configure TRPL by passing projection_kwargs:

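mean_bound, cov_bound = 0.03, 1e-3  # illustrative values (an assumption, not library defaults); tune per task
model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), projection_class=projection, projection_kwargs={'mean_bound': mean_bound, 'cov_bound': cov_bound}, verbose=1)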

For the available projection_kwargs, see Metastable Projections.

Full Covariance

SB3 does not support full covariances (only diagonal ones). We still provide support for full covariances via the separate PCA package. Since we do not actually want to use PCA ('Prior Conditioned Annealing') itself, we pass skip_conditioning=True, which causes the underlying noise to be used directly.

We therefore pass use_pca=True and set policy_kwargs['dist_kwargs'] to {'Base_Noise': 'WHITE', 'par_strength': 'FULL', 'skip_conditioning': True}:

# PPO and TRPL are supported (SAC is untested; PRs fixing issues are welcome)
model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, use_pca=True, policy_kwargs=dict(net_arch=[16], dist_kwargs={'par_strength': 'FULL', 'skip_conditioning': True}), projection_class=projection, verbose=1)

model.learn(total_timesteps=256)
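To illustrate what a full covariance adds over a diagonal one, the following NumPy sketch draws correlated actions by pushing white noise through a Cholesky factor. It is only an illustration of the general idea and not the code used by this repo or the PCA package; the mean, the covariance values and all variable names are assumptions chosen for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

mean = np.array([0.0, 0.0])
# Hypothetical target covariance with correlated action dimensions.
# A diagonal covariance could not express the 0.8 off-diagonal entries.
sigma = np.array([[1.0, 0.8],
                  [0.8, 1.0]])

# Lower-triangular factor L with sigma = L @ L.T.
chol = np.linalg.cholesky(sigma)

# White (i.i.d. standard normal) noise, one row per sample.
eps = rng.standard_normal((100_000, 2))

# Correlated actions: a = mean + L @ eps, applied row-wise.
actions = mean + eps @ chol.T

# The empirical covariance of the samples should be close to sigma.
empirical = np.cov(actions, rowvar=False)
print(np.round(empirical, 2))
```

With a diagonal covariance the off-diagonal entries of the empirical covariance would stay near zero, so the policy could not explore along correlated action directions.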

The supported values for par_strength are:

  • SCALAR: We only learn a single scalar value that is used along the whole diagonal. No covariance is modeled.

  • DIAG: We learn a diagonal covariance matrix (i.e. only variances).

  • FULL: We learn a full covariance matrix, induced via its Cholesky decomposition (except when the Wasserstein projection is used; then we use the Cholesky of the SPD matrix square root of the covariance matrix).

  • CONT_SCALAR: Same as SCALAR, but the scalar is not global, it is parameterized by the policy net (contextual).

  • CONT_DIAG: Same as DIAG, but the values are not global, they are parameterized by the policy net.

  • CONT_HYBRID: We learn a parametric diagonal that is scaled by the policy net.

  • CONT_FULL: Same as FULL, but parameterized by the policy net.
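The non-contextual options above can be sketched as follows. This is a minimal NumPy illustration of how each parameterization could induce a covariance matrix, written from the descriptions in the list; the function names are hypothetical and the actual implementation in this repo and the PCA package may differ.

```python
import numpy as np

def cov_scalar(std: float, dim: int) -> np.ndarray:
    """SCALAR: one shared standard deviation along the whole diagonal."""
    return (std ** 2) * np.eye(dim)

def cov_diag(stds: np.ndarray) -> np.ndarray:
    """DIAG: one standard deviation per action dimension, no correlations."""
    return np.diag(np.asarray(stds, dtype=float) ** 2)

def cov_full(chol: np.ndarray) -> np.ndarray:
    """FULL: covariance induced by a lower-triangular factor L, Sigma = L L^T."""
    L = np.tril(chol)
    return L @ L.T

def cov_full_from_sqrt_chol(chol_of_sqrt: np.ndarray) -> np.ndarray:
    """FULL with the Wasserstein projection, as we read the description above:
    L is the Cholesky factor of the SPD square root S, so S = L L^T and Sigma = S S."""
    L = np.tril(chol_of_sqrt)
    S = L @ L.T
    return S @ S

# Hypothetical learned parameters for a 2-dimensional action space.
L = np.array([[1.0, 0.0],
              [0.5, 2.0]])

sigma_scalar = cov_scalar(0.5, dim=2)
sigma_diag = cov_diag(np.array([0.5, 2.0]))
sigma_full = cov_full(L)
sigma_w2 = cov_full_from_sqrt_chol(L)

for name, sigma in [("SCALAR", sigma_scalar), ("DIAG", sigma_diag),
                    ("FULL", sigma_full), ("FULL (W2)", sigma_w2)]:
    print(name)
    print(sigma)
```

The CONT_* variants would produce the same kinds of matrices, except that the parameters (the scalar, the diagonal, or the factor) are outputs of the policy net for each observation instead of global learned values.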

License

Since this repo is an extension to Stable Baselines 3 by DLR-RM, it contains some of its code. SB3 is licensed under the MIT License, and so are our extensions.