- fix pyproject

- update hydra config to make experiment overrides smoother
- fix directory naming
- update readme
This commit is contained in:
Axel Brunnbauer 2025-07-15 22:20:32 -07:00
parent bb6889d308
commit 86fd47b04e
29 changed files with 149 additions and 61 deletions

View File

@ -17,7 +17,7 @@ We strongly recommend using the [uv tool](https://docs.astral.sh/uv/getting-star
With uv installed, you can install the project and all dependencies in a local virtual environment under `.venv` with one single command: With uv installed, you can install the project and all dependencies in a local virtual environment under `.venv` with one single command:
```bash ```bash
uv init uv sync
``` ```
Our installation requires a GPU with CUDA 12 compatible drivers. Our installation requires a GPU with CUDA 12 compatible drivers.
@ -36,7 +36,7 @@ pip install -e .
## Running Experiments ## Running Experiments
The main code for the algorithm is in `reppo/jaxrl/reppo.py` and `reppo/torchrl/reppo.py` respectively. The main code for the algorithm is in `src/jaxrl/reppo.py` and `src/torchrl/reppo.py` respectively.
In our tests, both versions produce similar returns up to seed variance. In our tests, both versions produce similar returns up to seed variance.
However, due to slight variations in the frameworks, we cannot always guarantee this. However, due to slight variations in the frameworks, we cannot always guarantee this.
@ -46,7 +46,7 @@ This can result in cases where the GPU is stalled if the CPU cannot provide inst
Our configurations are handled with [hydra.cc](https://hydra.cc/). This means parameters can be overwritten by using the syntax Our configurations are handled with [hydra.cc](https://hydra.cc/). This means parameters can be overwritten by using the syntax
```bash ```bash
python reppo/jaxrl/reppo.py PARAMETER=VALUE python src/jaxrl/reppo.py PARAMETER=VALUE
``` ```
By default, the environment type and name need to be provided. By default, the environment type and name need to be provided.
@ -56,11 +56,6 @@ The torch version support `env=mjx_dmc`, and `env=maniskill`. We additionally pr
The paper experiments can be reproduced easily by using the `experiment_override` settings. The paper experiments can be reproduced easily by using the `experiment_override` settings.
By specifying `experiment_override=mjx_smc_small_data` for example, you can run the variant of REPPO with a batch size of 32k samples. By specifying `experiment_override=mjx_smc_small_data` for example, you can run the variant of REPPO with a batch size of 32k samples.
> [!important]
> Note that by default, `experiment_override` overrides any parameters in the default config. This means if you specify `hyperparameters.num_steps=64 experiment_override=mjx_smc_small_data`, the number of steps will be 32.
> To appropriately set the number of steps, you would have to specify `experiment_override.hyperparameters.num_steps=64 experiment_override=mjx_smc_small_data`.
> In general, we recommend using the experiment overrides only when reproducing paper experiments.
## Contributing ## Contributing
We welcome contributions! Please feel free to submit issues and pull requests. We welcome contributions! Please feel free to submit issues and pull requests.

View File

@ -1,5 +1,5 @@
# @package _global_
hyperparameters:
lmbda: 0.95 lmbda: 0.95
num_epochs: 4 num_epochs: 4
aux_loss_mult: 1.0 aux_loss_mult: 1.0

View File

@ -1,3 +1,5 @@
# @package _global_
hyperparameters:
num_envs: 1024 num_envs: 1024
num_steps: 128 num_steps: 128
num_mini_batches: 64 num_mini_batches: 64

View File

@ -1,3 +1,5 @@
# @package _global_
hyperparameters:
num_envs: 1024 num_envs: 1024
num_steps: 64 num_steps: 64
num_mini_batches: 32 num_mini_batches: 32

View File

@ -1,3 +1,5 @@
# @package _global_
hyperparameters:
num_envs: 1024 num_envs: 1024
num_steps: 32 num_steps: 32
num_mini_batches: 16 num_mini_batches: 16

View File

@ -1,6 +1,7 @@
# @package _global_
hyperparameters:
gamma: 0.97 gamma: 0.97
critic_hidden_dim: 1024 critic_hidden_dim: 1024
num_envs: 1024 num_envs: 1024
num_steps: 128 num_steps: 128
num_mini_batches: 16 num_mini_batches: 16

View File

@ -1,6 +1,7 @@
# @package _global_
hyperparameters:
gamma: 0.97 gamma: 0.97
critic_hidden_dim: 1024 critic_hidden_dim: 1024
num_envs: 1024 num_envs: 1024
num_steps: 32 num_steps: 32
num_mini_batches: 4 num_mini_batches: 4

87
config/reppo.yaml Normal file
View File

@ -0,0 +1,87 @@
defaults:
- env: brax
- platform: torch
- _self_
hyperparameters:
# env and run settings (mostly don't touch)
total_time_steps: 50_000_000
normalize_env: true
max_episode_steps: 1000
eval_interval: 2
num_eval: 20
# optimization settings (seem very stable)
lr: 3e-4
anneal_lr: false
max_grad_norm: 0.5
polyak: 1.0 # maybe ablate ?
# problem discount settings (need tuning)
gamma: 0.99
lmbda: 0.95
lmbda_min: 0.50 # irrelevant if no exploration noise is added
# batch settings (need tuning for MJX humanoid)
num_steps: 128
num_mini_batches: 128
num_envs: 1024
num_epochs: 4
# exploration settings (currently not touched)
exploration_noise_max: 1.0
exploration_noise_min: 1.0
exploration_base_envs: 0
# critic architecture settings (need to be increased for MJX humanoid)
critic_hidden_dim: 512
actor_hidden_dim: 512
vmin: ${env.vmin}
vmax: ${env.vmax}
num_bins: 151
hl_gauss: true
use_critic_norm: true
num_critic_encoder_layers: 2
num_critic_head_layers: 2
num_critic_pred_layers: 2
use_simplical_embedding: False
# actor architecture settings (seem stable)
use_actor_norm: true
num_actor_layers: 3
actor_min_std: 0.0
# actor & critic loss settings (seem remarkably stable)
## kl settings
kl_start: 0.01
kl_bound: 0.1 # switched to tighter bounds for MJX
reduce_kl: true
reverse_kl: false # previous default "false"
update_kl_lagrangian: true
actor_kl_clip_mode: "clipped" # "full", "clipped", "kl_relu_clipped", "kl_bound_clipped", "value"
## entropy settings
ent_start: 0.01
ent_target_mult: 0.5
update_entropy_lagrangian: true
## auxiliary loss settings
aux_loss_mult: 1.0
measure_burnin: 3
name: "sac"
seed: 0
num_seeds: 1
tune: false
checkpoint_dir: null
num_trials: 10
tags: ["experimental"]
wandb:
mode: "online" # "online" enables wandb logging; set to "disabled" to turn it off
entity: "viper_svg"
project: "online_sac"
hydra:
job:
chdir: True

View File

@ -1,7 +1,7 @@
[project] [project]
name = "Relative Entropy Pathwise Policy Optimization" name = "reppo"
version = "0.1.0" version = "0.1.0"
description = "Code release for the REPPO paper" description = "Code release for the 'Relative Entropy Pathwise Policy Optimization' paper."
readme = "README.md" readme = "README.md"
requires-python = ">=3.12" requires-python = ">=3.12"
dependencies = [ dependencies = [
@ -26,7 +26,7 @@ dependencies = [
"tensordict>=0.8.3", "tensordict>=0.8.3",
"torch>=2.7.1", "torch>=2.7.1",
"tyro>=0.9.25", "tyro>=0.9.25",
"sapien>=3.0.0b1", "sapien>=3.0.0b1 ; sys_platform != 'darwin'",
"wandb>=0.20.1", "wandb>=0.20.1",
"torchinfo>=1.8.0", "torchinfo>=1.8.0",
"debugpy>=1.8.14", "debugpy>=1.8.14",

View File

@ -18,14 +18,14 @@ from jax.random import PRNGKey
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
import wandb import wandb
from reppo.env_utils.jax_wrappers import ( from src.env_utils.jax_wrappers import (
BraxGymnaxWrapper, BraxGymnaxWrapper,
ClipAction, ClipAction,
LogWrapper, LogWrapper,
MjxGymnaxWrapper, MjxGymnaxWrapper,
) )
from reppo.jaxrl import utils from src.jaxrl import utils
from reppo.jaxrl.normalization import NormalizationState, Normalizer from src.jaxrl.normalization import NormalizationState, Normalizer
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)

View File

@ -17,15 +17,15 @@ from jax.random import PRNGKey
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
import wandb import wandb
from reppo.env_utils.jax_wrappers import ( from src.env_utils.jax_wrappers import (
BraxGymnaxWrapper, BraxGymnaxWrapper,
ClipAction, ClipAction,
LogWrapper, LogWrapper,
MjxGymnaxWrapper, MjxGymnaxWrapper,
NormalizeVec, NormalizeVec,
) )
from reppo.jaxrl import utils from src.jaxrl import utils
from reppo.network_utils.jax_models import ( from src.network_utils.jax_models import (
CategoricalCriticNetwork, CategoricalCriticNetwork,
CriticNetwork, CriticNetwork,
SACActorNetworks, SACActorNetworks,
@ -928,10 +928,8 @@ def run(cfg: DictConfig, trial: optuna.Trial | None) -> float:
return (0.1 * sweep_metrics_array.mean() + sweep_metrics_array[:, -1].mean()).item() return (0.1 * sweep_metrics_array.mean() + sweep_metrics_array[:, -1].mean()).item()
@hydra.main(version_base=None, config_path="../../config", config_name="sac") @hydra.main(version_base=None, config_path="../../config", config_name="reppo")
def main(cfg: DictConfig): def main(cfg: DictConfig):
cfg.hyperparameters = OmegaConf.merge(cfg.hyperparameters, cfg.experiment_overrides)
run(cfg, trial=None) run(cfg, trial=None)

View File

@ -6,7 +6,7 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from flax import nnx from flax import nnx
from reppo.jaxrl import utils from src.jaxrl import utils
def torch_he_uniform( def torch_he_uniform(

View File

@ -4,7 +4,7 @@ from torch.distributions import constraints
from torch.distributions.transforms import Transform from torch.distributions.transforms import Transform
from torch.distributions.normal import Normal from torch.distributions.normal import Normal
from reppo.torchrl.reppo import hl_gauss from src.torchrl.reppo import hl_gauss
class TanhTransform(Transform): class TanhTransform(Transform):

View File

@ -26,9 +26,9 @@ import torch.optim as optim
from torchinfo import summary from torchinfo import summary
from tensordict import TensorDict from tensordict import TensorDict
from torch.amp import GradScaler from torch.amp import GradScaler
from reppo.torchrl.envs import make_envs from src.torchrl.envs import make_envs
from reppo.network_utils.torch_models import Actor, Critic from src.network_utils.torch_models import Actor, Critic
from reppo.torchrl.reppo import ( from src.torchrl.reppo import (
EmpiricalNormalization, EmpiricalNormalization,
hl_gauss, hl_gauss,
) )