- fix pyproject

- update hydra config to make experiment overrides smoother
- fix directory naming
- update readme
This commit is contained in:
Axel Brunnbauer 2025-07-15 22:20:32 -07:00
parent bb6889d308
commit 86fd47b04e
29 changed files with 149 additions and 61 deletions

View File

@ -17,7 +17,7 @@ We strongly recommend using the [uv tool](https://docs.astral.sh/uv/getting-star
With uv installed, you can install the project and all dependencies in a local virtual environment under `.venv` with one single command:
```bash
uv init
uv sync
```
Our installation requires a GPU with CUDA 12 compatible drivers.
@ -36,7 +36,7 @@ pip install -e .
## Running Experiments
The main code for the algorithm is in `reppo/jaxrl/reppo.py` and `reppo/torchrl/reppo.py` respectively.
The main code for the algorithm is in `src/reppo_jax/reppo.py` (JAX version) and `src/torchrl/reppo.py` (PyTorch version).
In our tests, both versions produce similar returns up to seed variance.
However, due to slight variations in the frameworks, we cannot always guarantee this.
@ -46,7 +46,7 @@ This can result in cases where the GPU is stalled if the CPU cannot provide inst
Our configurations are handled with [hydra.cc](https://hydra.cc/). This means parameters can be overridden on the command line using the syntax:
```bash
python reppo/jaxrl/reppo.py PARAMETER=VALUE
python src/reppo_jax/reppo.py PARAMETER=VALUE
```
By default, the environment type and name need to be provided.
@ -56,11 +56,6 @@ The torch version support `env=mjx_dmc`, and `env=maniskill`. We additionally pr
The paper experiments can be reproduced easily by using the `experiment_override` settings.
By specifying `experiment_override=mjx_smc_small_data` for example, you can run the variant of REPPO with a batch size of 32k samples.
> [!important]
> Note that by default, `experiment_override` overrides any parameters in the default config. This means if you specify `hyperparameters.num_steps=64 experiment_override=mjx_smc_small_data`, the number of steps will be 32.
> To appropriately set the number of steps, you would have to specify `experiment_override.hyperparameters.num_steps=64 experiment_override=mjx_smc_small_data`.
> In general, we recommend using the experiment overrides only when reproducing paper experiments.
## Contributing
We welcome contributions! Please feel free to submit issues and pull requests.

View File

@ -1,5 +1,5 @@
lmbda: 0.95
num_epochs: 4
aux_loss_mult: 1.0
# @package _global_
hyperparameters:
lmbda: 0.95
num_epochs: 4
aux_loss_mult: 1.0

View File

@ -1,5 +1,7 @@
num_envs: 1024
num_steps: 128
num_mini_batches: 64
num_epochs: 8
kl_bound: 0.1
# @package _global_
hyperparameters:
num_envs: 1024
num_steps: 128
num_mini_batches: 64
num_epochs: 8
kl_bound: 0.1

View File

@ -1,5 +1,7 @@
num_envs: 1024
num_steps: 64
num_mini_batches: 32
num_epochs: 8
kl_bound: 0.1
# @package _global_
hyperparameters:
num_envs: 1024
num_steps: 64
num_mini_batches: 32
num_epochs: 8
kl_bound: 0.1

View File

@ -1,5 +1,7 @@
num_envs: 1024
num_steps: 32
num_mini_batches: 16
num_epochs: 8
kl_bound: 0.1
# @package _global_
hyperparameters:
num_envs: 1024
num_steps: 32
num_mini_batches: 16
num_epochs: 8
kl_bound: 0.1

View File

@ -1,8 +1,9 @@
gamma: 0.97
critic_hidden_dim: 1024
num_envs: 1024
num_steps: 128
num_mini_batches: 16
num_epochs: 8
kl_bound: 0.1
# @package _global_
hyperparameters:
gamma: 0.97
critic_hidden_dim: 1024
num_envs: 1024
num_steps: 128
num_mini_batches: 16
num_epochs: 8
kl_bound: 0.1

View File

@ -1,8 +1,9 @@
gamma: 0.97
critic_hidden_dim: 1024
num_envs: 1024
num_steps: 32
num_mini_batches: 4
num_epochs: 8
kl_bound: 0.1
# @package _global_
hyperparameters:
gamma: 0.97
critic_hidden_dim: 1024
num_envs: 1024
num_steps: 32
num_mini_batches: 4
num_epochs: 8
kl_bound: 0.1

87
config/reppo.yaml Normal file
View File

@ -0,0 +1,87 @@
defaults:
- env: brax
- platform: torch
- _self_
hyperparameters:
# env and run settings (mostly don't touch)
total_time_steps: 50_000_000
normalize_env: true
max_episode_steps: 1000
eval_interval: 2
num_eval: 20
# optimization settings (seem very stable)
lr: 3e-4
anneal_lr: false
max_grad_norm: 0.5
polyak: 1.0 # maybe ablate ?
# problem discount settings (need tuning)
gamma: 0.99
lmbda: 0.95
lmbda_min: 0.50 # irrelevant if no exploration noise is added
# batch settings (need tuning for MJX humanoid)
num_steps: 128
num_mini_batches: 128
num_envs: 1024
num_epochs: 4
# exploration settings (currently not touched)
exploration_noise_max: 1.0
exploration_noise_min: 1.0
exploration_base_envs: 0
# critic architecture settings (need to be increased for MJX humanoid)
critic_hidden_dim: 512
actor_hidden_dim: 512
vmin: ${env.vmin}
vmax: ${env.vmax}
num_bins: 151
hl_gauss: true
use_critic_norm: true
num_critic_encoder_layers: 2
num_critic_head_layers: 2
num_critic_pred_layers: 2
use_simplical_embedding: False
# actor architecture settings (seem stable)
use_actor_norm: true
num_actor_layers: 3
actor_min_std: 0.0
# actor & critic loss settings (seem remarkably stable)
## kl settings
kl_start: 0.01
kl_bound: 0.1 # switched to tighter bounds for MJX
reduce_kl: true
reverse_kl: false # previous default "false"
update_kl_lagrangian: true
actor_kl_clip_mode: "clipped" # "full", "clipped", "kl_relu_clipped", "kl_bound_clipped", "value"
## entropy settings
ent_start: 0.01
ent_target_mult: 0.5
update_entropy_lagrangian: true
## auxiliary loss settings
aux_loss_mult: 1.0
measure_burnin: 3
name: "sac"
seed: 0
num_seeds: 1
tune: false
checkpoint_dir: null
num_trials: 10
tags: ["experimental"]
wandb:
mode: "online" # "online" enables wandb logging; use "disabled" to turn it off
entity: "viper_svg"
project: "online_sac"
hydra:
job:
chdir: True

View File

@ -1,7 +1,7 @@
[project]
name = "Relative Entropy Pathwise Policy Optimization"
name = "reppo"
version = "0.1.0"
description = "Code release for the REPPO paper"
description = "Code release for the 'Relative Entropy Pathwise Policy Optimization' paper."
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
@ -26,7 +26,7 @@ dependencies = [
"tensordict>=0.8.3",
"torch>=2.7.1",
"tyro>=0.9.25",
"sapien>=3.0.0b1",
"sapien>=3.0.0b1 ; sys_platform != 'darwin'",
"wandb>=0.20.1",
"torchinfo>=1.8.0",
"debugpy>=1.8.14",

View File

@ -18,14 +18,14 @@ from jax.random import PRNGKey
from omegaconf import DictConfig, OmegaConf
import wandb
from reppo.env_utils.jax_wrappers import (
from src.env_utils.jax_wrappers import (
BraxGymnaxWrapper,
ClipAction,
LogWrapper,
MjxGymnaxWrapper,
)
from reppo.jaxrl import utils
from reppo.jaxrl.normalization import NormalizationState, Normalizer
from src.jaxrl import utils
from src.jaxrl.normalization import NormalizationState, Normalizer
logging.basicConfig(level=logging.INFO)

View File

@ -17,15 +17,15 @@ from jax.random import PRNGKey
from omegaconf import DictConfig, OmegaConf
import wandb
from reppo.env_utils.jax_wrappers import (
from src.env_utils.jax_wrappers import (
BraxGymnaxWrapper,
ClipAction,
LogWrapper,
MjxGymnaxWrapper,
NormalizeVec,
)
from reppo.jaxrl import utils
from reppo.network_utils.jax_models import (
from src.jaxrl import utils
from src.network_utils.jax_models import (
CategoricalCriticNetwork,
CriticNetwork,
SACActorNetworks,
@ -928,10 +928,8 @@ def run(cfg: DictConfig, trial: optuna.Trial | None) -> float:
return (0.1 * sweep_metrics_array.mean() + sweep_metrics_array[:, -1].mean()).item()
@hydra.main(version_base=None, config_path="../../config", config_name="sac")
@hydra.main(version_base=None, config_path="../../config", config_name="reppo")
def main(cfg: DictConfig):
cfg.hyperparameters = OmegaConf.merge(cfg.hyperparameters, cfg.experiment_overrides)
run(cfg, trial=None)

View File

@ -6,7 +6,7 @@ import jax
import jax.numpy as jnp
from flax import nnx
from reppo.jaxrl import utils
from src.jaxrl import utils
def torch_he_uniform(

View File

@ -4,7 +4,7 @@ from torch.distributions import constraints
from torch.distributions.transforms import Transform
from torch.distributions.normal import Normal
from reppo.torchrl.reppo import hl_gauss
from src.torchrl.reppo import hl_gauss
class TanhTransform(Transform):

View File

@ -26,9 +26,9 @@ import torch.optim as optim
from torchinfo import summary
from tensordict import TensorDict
from torch.amp import GradScaler
from reppo.torchrl.envs import make_envs
from reppo.network_utils.torch_models import Actor, Critic
from reppo.torchrl.reppo import (
from src.torchrl.envs import make_envs
from src.network_utils.torch_models import Actor, Critic
from src.torchrl.reppo import (
EmpiricalNormalization,
hl_gauss,
)