- fix pyproject
- update hydra config to make experiment overrides smoother - fix directory naming - update readme
This commit is contained in:
parent
bb6889d308
commit
86fd47b04e
11
README.md
11
README.md
@ -17,7 +17,7 @@ We strongly recommend using the [uv tool](https://docs.astral.sh/uv/getting-star
|
||||
|
||||
With uv installed, you can install the project and all of its dependencies into a local virtual environment under `.venv` with a single command:
|
||||
```bash
|
||||
uv init
|
||||
uv sync
|
||||
```
|
||||
|
||||
Our installation requires a GPU with CUDA 12 compatible drivers.
|
||||
@ -36,7 +36,7 @@ pip install -e .
|
||||
|
||||
## Running Experiments
|
||||
|
||||
The main code for the algorithm is in `reppo/jaxrl/reppo.py` and `reppo/torchrl/reppo.py` respectively.
|
||||
The main code for the algorithm is in `src/reppo_jax/reppo.py` (JAX version) and `src/torchrl/reppo.py` (torch version), respectively.
|
||||
In our tests, both versions produce similar returns up to seed variance.
|
||||
However, due to slight variations in the frameworks, we cannot always guarantee this.
|
||||
|
||||
@ -46,7 +46,7 @@ This can result in cases where the GPU is stalled if the CPU cannot provide inst
|
||||
|
||||
Our configurations are handled with [hydra.cc](https://hydra.cc/). This means parameters can be overridden on the command line using the following syntax:
|
||||
```bash
|
||||
python reppo/jaxrl/reppo.py PARAMETER=VALUE
|
||||
python src/reppo_jax/reppo.py PARAMETER=VALUE
|
||||
```
|
||||
|
||||
By default, the environment type and name need to be provided.
|
||||
@ -56,11 +56,6 @@ The torch version supports `env=mjx_dmc`, and `env=maniskill`. We additionally pr
|
||||
The paper experiments can be reproduced easily by using the `experiment_override` settings.
|
||||
For example, by specifying `experiment_override=mjx_smc_small_data`, you can run the variant of REPPO with a batch size of 32k samples.
|
||||
|
||||
> [!important]
|
||||
> Note that by default, `experiment_override` overrides any parameters in the default config. This means if you specify `hyperparameters.num_steps=64 experiment_override=mjx_smc_small_data`, the number of steps will be 32.
|
||||
> To appropriately set the number of steps, you would have to specify `experiment_override.hyperparameters.num_steps=64 experiment_override=mjx_smc_small_data`.
|
||||
> In general, we recommend using the experiment overrides only when reproducing paper experiments.
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome contributions! Please feel free to submit issues and pull requests.
|
||||
|
@ -1,5 +1,5 @@
|
||||
lmbda: 0.95
|
||||
|
||||
num_epochs: 4
|
||||
|
||||
aux_loss_mult: 1.0
|
||||
# @package _global_
|
||||
hyperparameters:
|
||||
lmbda: 0.95
|
||||
num_epochs: 4
|
||||
aux_loss_mult: 1.0
|
@ -1,5 +1,7 @@
|
||||
num_envs: 1024
|
||||
num_steps: 128
|
||||
num_mini_batches: 64
|
||||
num_epochs: 8
|
||||
kl_bound: 0.1
|
||||
# @package _global_
|
||||
hyperparameters:
|
||||
num_envs: 1024
|
||||
num_steps: 128
|
||||
num_mini_batches: 64
|
||||
num_epochs: 8
|
||||
kl_bound: 0.1
|
@ -1,5 +1,7 @@
|
||||
num_envs: 1024
|
||||
num_steps: 64
|
||||
num_mini_batches: 32
|
||||
num_epochs: 8
|
||||
kl_bound: 0.1
|
||||
# @package _global_
|
||||
hyperparameters:
|
||||
num_envs: 1024
|
||||
num_steps: 64
|
||||
num_mini_batches: 32
|
||||
num_epochs: 8
|
||||
kl_bound: 0.1
|
@ -1,5 +1,7 @@
|
||||
num_envs: 1024
|
||||
num_steps: 32
|
||||
num_mini_batches: 16
|
||||
num_epochs: 8
|
||||
kl_bound: 0.1
|
||||
# @package _global_
|
||||
hyperparameters:
|
||||
num_envs: 1024
|
||||
num_steps: 32
|
||||
num_mini_batches: 16
|
||||
num_epochs: 8
|
||||
kl_bound: 0.1
|
@ -1,8 +1,9 @@
|
||||
gamma: 0.97
|
||||
critic_hidden_dim: 1024
|
||||
|
||||
num_envs: 1024
|
||||
num_steps: 128
|
||||
num_mini_batches: 16
|
||||
num_epochs: 8
|
||||
kl_bound: 0.1
|
||||
# @package _global_
|
||||
hyperparameters:
|
||||
gamma: 0.97
|
||||
critic_hidden_dim: 1024
|
||||
num_envs: 1024
|
||||
num_steps: 128
|
||||
num_mini_batches: 16
|
||||
num_epochs: 8
|
||||
kl_bound: 0.1
|
@ -1,8 +1,9 @@
|
||||
gamma: 0.97
|
||||
critic_hidden_dim: 1024
|
||||
|
||||
num_envs: 1024
|
||||
num_steps: 32
|
||||
num_mini_batches: 4
|
||||
num_epochs: 8
|
||||
kl_bound: 0.1
|
||||
# @package _global_
|
||||
hyperparameters:
|
||||
gamma: 0.97
|
||||
critic_hidden_dim: 1024
|
||||
num_envs: 1024
|
||||
num_steps: 32
|
||||
num_mini_batches: 4
|
||||
num_epochs: 8
|
||||
kl_bound: 0.1
|
87
config/reppo.yaml
Normal file
87
config/reppo.yaml
Normal file
@ -0,0 +1,87 @@
|
||||
defaults:
|
||||
- env: brax
|
||||
- platform: torch
|
||||
- _self_
|
||||
|
||||
hyperparameters:
|
||||
# env and run settings (mostly don't touch)
|
||||
total_time_steps: 50_000_000
|
||||
normalize_env: true
|
||||
max_episode_steps: 1000
|
||||
eval_interval: 2
|
||||
num_eval: 20
|
||||
|
||||
# optimization settings (seem very stable)
|
||||
lr: 3e-4
|
||||
anneal_lr: false
|
||||
max_grad_norm: 0.5
|
||||
polyak: 1.0 # maybe ablate ?
|
||||
|
||||
# problem discount settings (need tuning)
|
||||
gamma: 0.99
|
||||
lmbda: 0.95
|
||||
lmbda_min: 0.50 # irrelevant if no exploration noise is added
|
||||
|
||||
# batch settings (need tuning for MJX humanoid)
|
||||
num_steps: 128
|
||||
num_mini_batches: 128
|
||||
num_envs: 1024
|
||||
num_epochs: 4
|
||||
|
||||
# exploration settings (currently not touched)
|
||||
exploration_noise_max: 1.0
|
||||
exploration_noise_min: 1.0
|
||||
exploration_base_envs: 0
|
||||
|
||||
# critic architecture settings (need to be increased for MJX humanoid)
|
||||
critic_hidden_dim: 512
|
||||
actor_hidden_dim: 512
|
||||
vmin: ${env.vmin}
|
||||
vmax: ${env.vmax}
|
||||
num_bins: 151
|
||||
hl_gauss: true
|
||||
use_critic_norm: true
|
||||
num_critic_encoder_layers: 2
|
||||
num_critic_head_layers: 2
|
||||
num_critic_pred_layers: 2
|
||||
use_simplical_embedding: False
|
||||
|
||||
# actor architecture settings (seem stable)
|
||||
use_actor_norm: true
|
||||
num_actor_layers: 3
|
||||
actor_min_std: 0.0
|
||||
|
||||
# actor & critic loss settings (seem remarkably stable)
|
||||
## kl settings
|
||||
kl_start: 0.01
|
||||
kl_bound: 0.1 # switched to tighter bounds for MJX
|
||||
reduce_kl: true
|
||||
reverse_kl: false # previous default "false"
|
||||
update_kl_lagrangian: true
|
||||
actor_kl_clip_mode: "clipped" # "full", "clipped", "kl_relu_clipped", "kl_bound_clipped", "value"
|
||||
## entropy settings
|
||||
ent_start: 0.01
|
||||
ent_target_mult: 0.5
|
||||
update_entropy_lagrangian: true
|
||||
## auxiliary loss settings
|
||||
aux_loss_mult: 1.0
|
||||
|
||||
|
||||
measure_burnin: 3
|
||||
|
||||
|
||||
name: "sac"
|
||||
seed: 0
|
||||
num_seeds: 1
|
||||
tune: false
|
||||
checkpoint_dir: null
|
||||
num_trials: 10
|
||||
tags: ["experimental"]
|
||||
wandb:
|
||||
mode: "online" # set to online to activate wandb
|
||||
entity: "viper_svg"
|
||||
project: "online_sac"
|
||||
|
||||
hydra:
|
||||
job:
|
||||
chdir: True
|
@ -1,7 +1,7 @@
|
||||
[project]
|
||||
name = "Relative Entropy Pathwise Policy Optimization"
|
||||
name = "reppo"
|
||||
version = "0.1.0"
|
||||
description = "Code release for the REPPO paper"
|
||||
description = "Code release for the 'Relative Entropy Pathwise Policy Optimization'."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
@ -26,7 +26,7 @@ dependencies = [
|
||||
"tensordict>=0.8.3",
|
||||
"torch>=2.7.1",
|
||||
"tyro>=0.9.25",
|
||||
"sapien>=3.0.0b1",
|
||||
"sapien>=3.0.0b1 ; sys_platform != 'darwin'",
|
||||
"wandb>=0.20.1",
|
||||
"torchinfo>=1.8.0",
|
||||
"debugpy>=1.8.14",
|
||||
|
@ -18,14 +18,14 @@ from jax.random import PRNGKey
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
import wandb
|
||||
from reppo.env_utils.jax_wrappers import (
|
||||
from src.env_utils.jax_wrappers import (
|
||||
BraxGymnaxWrapper,
|
||||
ClipAction,
|
||||
LogWrapper,
|
||||
MjxGymnaxWrapper,
|
||||
)
|
||||
from reppo.jaxrl import utils
|
||||
from reppo.jaxrl.normalization import NormalizationState, Normalizer
|
||||
from src.jaxrl import utils
|
||||
from src.jaxrl.normalization import NormalizationState, Normalizer
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
@ -17,15 +17,15 @@ from jax.random import PRNGKey
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
import wandb
|
||||
from reppo.env_utils.jax_wrappers import (
|
||||
from src.env_utils.jax_wrappers import (
|
||||
BraxGymnaxWrapper,
|
||||
ClipAction,
|
||||
LogWrapper,
|
||||
MjxGymnaxWrapper,
|
||||
NormalizeVec,
|
||||
)
|
||||
from reppo.jaxrl import utils
|
||||
from reppo.network_utils.jax_models import (
|
||||
from src.jaxrl import utils
|
||||
from src.network_utils.jax_models import (
|
||||
CategoricalCriticNetwork,
|
||||
CriticNetwork,
|
||||
SACActorNetworks,
|
||||
@ -928,10 +928,8 @@ def run(cfg: DictConfig, trial: optuna.Trial | None) -> float:
|
||||
return (0.1 * sweep_metrics_array.mean() + sweep_metrics_array[:, -1].mean()).item()
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="../../config", config_name="sac")
|
||||
@hydra.main(version_base=None, config_path="../../config", config_name="reppo")
|
||||
def main(cfg: DictConfig):
|
||||
cfg.hyperparameters = OmegaConf.merge(cfg.hyperparameters, cfg.experiment_overrides)
|
||||
|
||||
run(cfg, trial=None)
|
||||
|
||||
|
@ -6,7 +6,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from flax import nnx
|
||||
|
||||
from reppo.jaxrl import utils
|
||||
from src.jaxrl import utils
|
||||
|
||||
|
||||
def torch_he_uniform(
|
@ -4,7 +4,7 @@ from torch.distributions import constraints
|
||||
from torch.distributions.transforms import Transform
|
||||
from torch.distributions.normal import Normal
|
||||
|
||||
from reppo.torchrl.reppo import hl_gauss
|
||||
from src.torchrl.reppo import hl_gauss
|
||||
|
||||
|
||||
class TanhTransform(Transform):
|
@ -26,9 +26,9 @@ import torch.optim as optim
|
||||
from torchinfo import summary
|
||||
from tensordict import TensorDict
|
||||
from torch.amp import GradScaler
|
||||
from reppo.torchrl.envs import make_envs
|
||||
from reppo.network_utils.torch_models import Actor, Critic
|
||||
from reppo.torchrl.reppo import (
|
||||
from src.torchrl.envs import make_envs
|
||||
from src.network_utils.torch_models import Actor, Critic
|
||||
from src.torchrl.reppo import (
|
||||
EmpiricalNormalization,
|
||||
hl_gauss,
|
||||
)
|
Loading…
Reference in New Issue
Block a user