From eed4363ddda335025313d49bf2cb7759caed3122 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 30 Mar 2024 14:41:43 +0100 Subject: [PATCH] Moved README.md --- metastable_baselines2/README.md => README.md | 44 ++++++++++++++++++-- 1 file changed, 41 insertions(+), 3 deletions(-) rename metastable_baselines2/README.md => README.md (53%) diff --git a/metastable_baselines2/README.md b/README.md similarity index 53% rename from metastable_baselines2/README.md rename to README.md index 2153285..34cbceb 100644 --- a/metastable_baselines2/README.md +++ b/README.md @@ -26,27 +26,65 @@ KL Projections require ALR's ITPAL as an additional dependecy. Then install this repo as a package: -``` +```bash pip install -e . ``` ## Usage +### TRPL + TRPL can be used just like SB3's PPO: -``` +```python import gymnasium as gym from metastable_baselines2 import TRPL +env_id = 'LunarLanderContinuous-v2' projection = 'Wasserstein' # or Frobenius or KL -model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), projection_class=projection, projection_kwargs={'mean_bound': mean_bound, 'cov_bound': cov_bound}, verbose=1) +model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), projection_class=projection, verbose=1) model.learn(total_timesteps=100) ``` +Configure TRPL py passing `projection_kwargs` to TRPL: + +```python +model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), projection_class=projection, projection_kwargs={'mean_bound': mean_bound, 'cov_bound': cov_bound}, verbose=1) +``` + For avaible projection_kwargs have a look at [Metastable Projections](https://git.dominik-roth.eu/dodox/metastable-projections). +### Full Covariance + +SB3 does not support full covariances (only diagonal). We still provide support for full covariances via the seperate PCA package. (But since we don't actually want to use PCA ('Prior Conditioned Annealing'), we pass 'skip_conditioning=True'; this will lead to the underlying Noise being used directly.) + +We therefore pass `use_pca=True` and `policy_kwargs.dist_kwargs = {'Base_Noise': 'WHITE', par_strength: 'FULL', skip_conditioning=True}` + +```python +# We support PPO and TRPL, (SAC is untested, we are open to PRs fixing issues) +model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16], ), projection_class=projection, verbose=1) + +model.learn(total_timesteps=100) +``` + +The supportted values for `par_strength` are: + +​ SCALAR: We only learn a single scalar value, that is used along the whole diagonal. No covariance is modeled. + +​ DIAG: We learn a diagonal covariance matrix. (e.g. only variances). + +​ FULL: We learn a full covariance matrix, induced via cholesky decomp. + +​ CONT_SCALAR: Same as SCALAR, but the scalar is not global, it is parameterized by the policy net. + +​ CONT_DIAG: Same as DIAG, but the values are not global, they are parameterized by the policy net. + +​ CONT_HYBRID: We learn a parameric diagonal, that is scaled by the policy net. + +​ CONT_FULL: Same as FULL, but parameterized by the policy net. + ## License Since this Repo is an extension to [Stable Baselines 3 by DLR-RM](https://github.com/DLR-RM/stable-baselines3), it contains some of it's code. SB3 is licensed under the [MIT-License](https://github.com/DLR-RM/stable-baselines3/blob/master/LICENSE).