Tweaked README
parent 0b4956873c
commit 8f244455bd

README.md: 30 lines changed
@@ -9,8 +9,8 @@ An extension to Stable Baselines 3. Based on Metastable Baselines 1.
 This repo provides:
 
 - An implementation of ["Differentiable Trust Region Layers for Deep Reinforcement Learning" by Fabian Otto et al. (TRPL)](https://arxiv.org/abs/2101.09207)
-- Support for Contextual Covariances (via PCA)
-- Support for Full Covariances (via PCA)
+- Support for Contextual Covariances
+- Support for Full Covariances
 
 ## Installation
 
@@ -46,7 +46,7 @@ projection = 'Wasserstein' # or Frobenius or KL
 
 model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), projection_class=projection, verbose=1)
 
-model.learn(total_timesteps=100)
+model.learn(total_timesteps=256)
 ```
 
 Configure TRPL py passing `projection_kwargs` to TRPL:
@@ -55,7 +55,7 @@ Configure TRPL py passing `projection_kwargs` to TRPL:
 model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), projection_class=projection, projection_kwargs={'mean_bound': mean_bound, 'cov_bound': cov_bound}, verbose=1)
 ```
 
-For avaible projection_kwargs have a look at [Metastable Projections](https://git.dominik-roth.eu/dodox/metastable-projections).
+For available projection_kwargs have a look at [Metastable Projections](https://git.dominik-roth.eu/dodox/metastable-projections).
 
 ### Full Covariance
 
@@ -65,27 +65,27 @@ We therefore pass `use_pca=True` and `policy_kwargs.dist_kwargs = {'Base_Noise':
 
 ```python
 # We support PPO and TRPL, (SAC is untested, we are open to PRs fixing issues)
-model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, use_pca=True, policy_kwargs=dict(net_arch=[16], dist_kwargs={'par_strengt        h': 'FULL', 'skip_conditioning': True}), projection_class=projection, verbose=1)
+model = TRPL("MlpPolicy", env_id, n_steps=128, seed=0, use_pca=True, policy_kwargs=dict(net_arch=[16], dist_kwargs={'par_strength': 'FULL', 'skip_conditioning': True}), projection_class=projection, verbose=1)
 
-model.learn(total_timesteps=100)
+model.learn(total_timesteps=256)
 ```
 
-The supportted values for `par_strength` are:
+The supported values for `par_strength` are:
+- `SCALAR`: We only learn a single scalar value, that is used along the whole diagonal. No covariance is modeled.
 
-    SCALAR: We only learn a single scalar value, that is used along the whole diagonal. No covariance is modeled.
+- `DIAG`: We learn a diagonal covariance matrix. (e.g. only variances).
 
-    DIAG: We learn a diagonal covariance matrix. (e.g. only variances).
+- `FULL`: We learn a full covariance matrix, induced via Cholesky decomp (except when Wasserstein Projection is used; then we use the Cholesky of the SPD matrix sqrt of the covariance marix).
 
-    FULL: We learn a full covariance matrix, induced via cholesky decomp.
+- `CONT_SCALAR`: Same as `SCALAR`, but the scalar is not global, it is parameterized by the policy net (contextual).
 
-    CONT_SCALAR: Same as SCALAR, but the scalar is not global, it is parameterized by the policy net.
+- `CONT_DIAG`: Same as `DIAG`, but the values are not global, they are parameterized by the policy net.
 
-    CONT_DIAG: Same as DIAG, but the values are not global, they are parameterized by the policy net.
+- `CONT_HYBRID`: We learn a parameric diagonal, that is scaled by the policy net.
 
-    CONT_HYBRID: We learn a parameric diagonal, that is scaled by the policy net.
+- `CONT_FULL`: Same as `FULL`, but parameterized by the policy net.
 
-    CONT_FULL: Same as FULL, but parameterized by the policy net.
 
 ## License
 
-Since this Repo is an extension to [Stable Baselines 3 by DLR-RM](https://github.com/DLR-RM/stable-baselines3), it contains some of it's code. SB3 is licensed under the [MIT-License](https://github.com/DLR-RM/stable-baselines3/blob/master/LICENSE).
+Since this Repo is an extension to [Stable Baselines 3 by DLR-RM](https://github.com/DLR-RM/stable-baselines3), it contains some of it's code. SB3 is licensed under the [MIT-License](https://github.com/DLR-RM/stable-baselines3/blob/master/LICENSE), and so are our extensions.
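
Note on the `par_strength` list in the README: the three non-contextual options differ only in how a flat parameter vector is turned into a covariance matrix. The sketch below illustrates that mapping with plain NumPy. It is not taken from this repository: the helper name `build_covariance`, the parameter layout, and the use of `exp` to keep variances positive are all assumptions made for illustration, and the Wasserstein special case mentioned for `FULL` is not covered.

```python
import numpy as np


def build_covariance(par_strength, params, dim):
    """Hypothetical helper: map a flat parameter vector to a covariance matrix."""
    params = np.asarray(params, dtype=float)
    if par_strength == "SCALAR":
        # A single shared variance along the whole diagonal.
        return np.exp(params[0]) * np.eye(dim)
    if par_strength == "DIAG":
        # One variance per dimension, no off-diagonal terms.
        return np.diag(np.exp(params[:dim]))
    if par_strength == "FULL":
        # Fill a lower-triangular Cholesky factor, then form chol @ chol.T.
        chol = np.zeros((dim, dim))
        chol[np.tril_indices(dim)] = params[: dim * (dim + 1) // 2]
        # Exponentiate the diagonal (an assumption) so the factor is invertible.
        diag = np.diag_indices(dim)
        chol[diag] = np.exp(chol[diag])
        return chol @ chol.T
    raise ValueError(f"unsupported par_strength: {par_strength}")


cov = build_covariance("FULL", [0.1, 0.5, -0.2, 0.3, -0.4, 0.0], dim=3)
print(cov)
```

Under the same assumptions, the `CONT_*` variants would feed this mapping with parameters produced by the policy net for each observation instead of a single global vector.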