diff --git a/README.md b/README.md index 077b863..b3f3de8 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ We strongly recommend using the [uv tool](https://docs.astral.sh/uv/getting-star With uv installed, you can install the project and all dependencies in a local virtual environment under `.venv` with one single command: ```bash -uv init +uv sync ``` Our installation requires a GPU with CUDA 12 compatible drivers. @@ -36,7 +36,7 @@ pip install -e . ## Running Experiments -The main code for the algorithm is in `reppo/jaxrl/reppo.py` and `reppo/torchrl/reppo.py` respectively. +The main code for the algorithm is in `src/reppo_jax/reppo.py` and `src/torchrl/reppo.py` respectively. In our tests, both versions produce similar returns up to seed variance. However, due to slight variations in the frameworks, we cannot always guarantee this. @@ -46,7 +46,7 @@ This can result in cases where the GPU is stalled if the CPU cannot provide inst Our configurations are handled with [hydra.cc](https://hydra.cc/). This means parameters can be overwritten by using the syntax ```bash -python reppo/jaxrl/reppo.py PARAMETER=VALUE +python src/reppo_jax/reppo.py PARAMETER=VALUE ``` By default, the environment type and name need to be provided. @@ -56,11 +56,6 @@ The torch version support `env=mjx_dmc`, and `env=maniskill`. We additionally pr The paper experiments can be reproduced easily by using the `experiment_override` settings. By specifying `experiment_override=mjx_smc_small_data` for example, you can run the variant of REPPO with a batch size of 32k samples. -> [!important] -> Note that by default, `experiment_override` overrides any parameters in the default config. This means if you specify `hyperparameters.num_steps=64 experiment_override=mjx_smc_small_data`, the number of steps will be 32. -> To appropriately set the number of steps, you would have to specify `experiment_override.hyperparameters.num_steps=64 experiment_override=mjx_smc_small_data`. -> In general, we recommend using the experiment overrides only when reproducing paper experiments. - ## Contributing We welcome contributions! Please feel free to submit issues and pull requests. diff --git a/config/experiment_overrides/maniskill.yaml b/config/experiment_overrides/maniskill.yaml index d33b273..8e28b8e 100644 --- a/config/experiment_overrides/maniskill.yaml +++ b/config/experiment_overrides/maniskill.yaml @@ -1,5 +1,5 @@ -lmbda: 0.95 - -num_epochs: 4 - -aux_loss_mult: 1.0 \ No newline at end of file +# @package _global_ +hyperparameters: + lmbda: 0.95 + num_epochs: 4 + aux_loss_mult: 1.0 \ No newline at end of file diff --git a/config/experiment_overrides/mjx_dmc_large_data.yaml b/config/experiment_overrides/mjx_dmc_large_data.yaml index ecb4f92..c5a29bf 100644 --- a/config/experiment_overrides/mjx_dmc_large_data.yaml +++ b/config/experiment_overrides/mjx_dmc_large_data.yaml @@ -1,5 +1,7 @@ -num_envs: 1024 -num_steps: 128 -num_mini_batches: 64 -num_epochs: 8 -kl_bound: 0.1 \ No newline at end of file +# @package _global_ +hyperparameters: + num_envs: 1024 + num_steps: 128 + num_mini_batches: 64 + num_epochs: 8 + kl_bound: 0.1 \ No newline at end of file diff --git a/config/experiment_overrides/mjx_dmc_medium_data.yaml b/config/experiment_overrides/mjx_dmc_medium_data.yaml index 5a53de1..c5b01b3 100644 --- a/config/experiment_overrides/mjx_dmc_medium_data.yaml +++ b/config/experiment_overrides/mjx_dmc_medium_data.yaml @@ -1,5 +1,7 @@ -num_envs: 1024 -num_steps: 64 -num_mini_batches: 32 -num_epochs: 8 -kl_bound: 0.1 \ No newline at end of file +# @package _global_ +hyperparameters: + num_envs: 1024 + num_steps: 64 + num_mini_batches: 32 + num_epochs: 8 + kl_bound: 0.1 \ No newline at end of file diff --git a/config/experiment_overrides/mjx_dmc_small_data.yaml b/config/experiment_overrides/mjx_dmc_small_data.yaml index 8ca8427..0316ea2 100644 --- a/config/experiment_overrides/mjx_dmc_small_data.yaml +++ b/config/experiment_overrides/mjx_dmc_small_data.yaml @@ -1,5 +1,7 @@ -num_envs: 1024 -num_steps: 32 -num_mini_batches: 16 -num_epochs: 8 -kl_bound: 0.1 \ No newline at end of file +# @package _global_ +hyperparameters: + num_envs: 1024 + num_steps: 32 + num_mini_batches: 16 + num_epochs: 8 + kl_bound: 0.1 \ No newline at end of file diff --git a/config/experiment_overrides/mjx_humanoid_large_data.yaml b/config/experiment_overrides/mjx_humanoid_large_data.yaml index 3b44672..fc65871 100644 --- a/config/experiment_overrides/mjx_humanoid_large_data.yaml +++ b/config/experiment_overrides/mjx_humanoid_large_data.yaml @@ -1,8 +1,9 @@ -gamma: 0.97 -critic_hidden_dim: 1024 - -num_envs: 1024 -num_steps: 128 -num_mini_batches: 16 -num_epochs: 8 -kl_bound: 0.1 \ No newline at end of file +# @package _global_ +hyperparameters: + gamma: 0.97 + critic_hidden_dim: 1024 + num_envs: 1024 + num_steps: 128 + num_mini_batches: 16 + num_epochs: 8 + kl_bound: 0.1 \ No newline at end of file diff --git a/config/experiment_overrides/mjx_humanoid_small_data.yaml b/config/experiment_overrides/mjx_humanoid_small_data.yaml index 0d1d88b..c46f52e 100644 --- a/config/experiment_overrides/mjx_humanoid_small_data.yaml +++ b/config/experiment_overrides/mjx_humanoid_small_data.yaml @@ -1,8 +1,9 @@ -gamma: 0.97 -critic_hidden_dim: 1024 - -num_envs: 1024 -num_steps: 32 -num_mini_batches: 4 -num_epochs: 8 -kl_bound: 0.1 \ No newline at end of file +# @package _global_ +hyperparameters: + gamma: 0.97 + critic_hidden_dim: 1024 + num_envs: 1024 + num_steps: 32 + num_mini_batches: 4 + num_epochs: 8 + kl_bound: 0.1 \ No newline at end of file diff --git a/config/reppo.yaml b/config/reppo.yaml new file mode 100644 index 0000000..aec3af7 --- /dev/null +++ b/config/reppo.yaml @@ -0,0 +1,87 @@ +defaults: + - env: brax + - platform: torch + - _self_ + +hyperparameters: + # env and run settings (mostly don't touch) + total_time_steps: 50_000_000 + normalize_env: true + max_episode_steps: 1000 + eval_interval: 2 + num_eval: 20 + + # optimization settings (seem very stable) + lr: 3e-4 + anneal_lr: false + max_grad_norm: 0.5 + polyak: 1.0 # maybe ablate ? + + # problem discount settings (need tuning) + gamma: 0.99 + lmbda: 0.95 + lmbda_min: 0.50 # irrelevant if no exploration noise is added + + # batch settings (need tuning for MJX humanoid) + num_steps: 128 + num_mini_batches: 128 + num_envs: 1024 + num_epochs: 4 + + # exploration settings (currently not touched) + exploration_noise_max: 1.0 + exploration_noise_min: 1.0 + exploration_base_envs: 0 + + # critic architecture settings (need to be increased for MJX humanoid) + critic_hidden_dim: 512 + actor_hidden_dim: 512 + vmin: ${env.vmin} + vmax: ${env.vmax} + num_bins: 151 + hl_gauss: true + use_critic_norm: true + num_critic_encoder_layers: 2 + num_critic_head_layers: 2 + num_critic_pred_layers: 2 + use_simplical_embedding: False + + # actor architecture settings (seem stable) + use_actor_norm: true + num_actor_layers: 3 + actor_min_std: 0.0 + + # actor & critic loss settings (seem remarkably stable) + ## kl settings + kl_start: 0.01 + kl_bound: 0.1 # switched to tighter bounds for MJX + reduce_kl: true + reverse_kl: false # previous default "false" + update_kl_lagrangian: true + actor_kl_clip_mode: "clipped" # "full", "clipped", "kl_relu_clipped", "kl_bound_clipped", "value" + ## entropy settings + ent_start: 0.01 + ent_target_mult: 0.5 + update_entropy_lagrangian: true + ## auxiliary loss settings + aux_loss_mult: 1.0 + + +measure_burnin: 3 + + +name: "sac" +seed: 0 +num_seeds: 1 +tune: false +checkpoint_dir: null +num_trials: 10 +tags: ["experimental"] +wandb: + mode: "online" # set to online to activate wandb + entity: "viper_svg" + project: "online_sac" + +hydra: + job: + chdir: True diff --git a/pyproject.toml b/pyproject.toml index 49bc69c..61223ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] -name = "Relative Entropy Pathwise Policy Optimization" +name = "reppo" version = "0.1.0" -description = "Code release for the REPPO paper" +description = "Code release for the 'Relative Entropy Pathwise Policy Optimization'." readme = "README.md" requires-python = ">=3.12" dependencies = [ @@ -26,7 +26,7 @@ dependencies = [ "tensordict>=0.8.3", "torch>=2.7.1", "tyro>=0.9.25", - "sapien>=3.0.0b1", + "sapien>=3.0.0b1 ; sys_platform != 'darwin'", "wandb>=0.20.1", "torchinfo>=1.8.0", "debugpy>=1.8.14", diff --git a/reppo/env_utils/jax_wrappers.py b/src/env_utils/jax_wrappers.py similarity index 100% rename from reppo/env_utils/jax_wrappers.py rename to src/env_utils/jax_wrappers.py diff --git a/reppo/env_utils/torch_wrappers/humanoid_bench_env.py b/src/env_utils/torch_wrappers/humanoid_bench_env.py similarity index 100% rename from reppo/env_utils/torch_wrappers/humanoid_bench_env.py rename to src/env_utils/torch_wrappers/humanoid_bench_env.py diff --git a/reppo/env_utils/torch_wrappers/isaaclab_env.py b/src/env_utils/torch_wrappers/isaaclab_env.py similarity index 100% rename from reppo/env_utils/torch_wrappers/isaaclab_env.py rename to src/env_utils/torch_wrappers/isaaclab_env.py diff --git a/reppo/env_utils/torch_wrappers/maniskill_wrapper.py b/src/env_utils/torch_wrappers/maniskill_wrapper.py similarity index 100% rename from reppo/env_utils/torch_wrappers/maniskill_wrapper.py rename to src/env_utils/torch_wrappers/maniskill_wrapper.py diff --git a/reppo/env_utils/torch_wrappers/mtbench_env.py b/src/env_utils/torch_wrappers/mtbench_env.py similarity index 100% rename from reppo/env_utils/torch_wrappers/mtbench_env.py rename to src/env_utils/torch_wrappers/mtbench_env.py diff --git a/reppo/env_utils/torch_wrappers/mujoco_playground_env.py b/src/env_utils/torch_wrappers/mujoco_playground_env.py similarity index 100% rename from reppo/env_utils/torch_wrappers/mujoco_playground_env.py rename to src/env_utils/torch_wrappers/mujoco_playground_env.py diff --git a/reppo/jaxrl/__init__.py b/src/jaxrl/__init__.py similarity index 100% rename from reppo/jaxrl/__init__.py rename to src/jaxrl/__init__.py diff --git a/reppo/jaxrl/normalization.py b/src/jaxrl/normalization.py similarity index 100% rename from reppo/jaxrl/normalization.py rename to src/jaxrl/normalization.py diff --git a/reppo/jaxrl/ppo_mjx.py b/src/jaxrl/ppo_mjx.py similarity index 99% rename from reppo/jaxrl/ppo_mjx.py rename to src/jaxrl/ppo_mjx.py index e860701..0d55cd7 100644 --- a/reppo/jaxrl/ppo_mjx.py +++ b/src/jaxrl/ppo_mjx.py @@ -18,14 +18,14 @@ from jax.random import PRNGKey from omegaconf import DictConfig, OmegaConf import wandb -from reppo.env_utils.jax_wrappers import ( +from src.env_utils.jax_wrappers import ( BraxGymnaxWrapper, ClipAction, LogWrapper, MjxGymnaxWrapper, ) -from reppo.jaxrl import utils -from reppo.jaxrl.normalization import NormalizationState, Normalizer +from src.jaxrl import utils +from src.jaxrl.normalization import NormalizationState, Normalizer logging.basicConfig(level=logging.INFO) diff --git a/reppo/jaxrl/reppo.py b/src/jaxrl/reppo.py similarity index 99% rename from reppo/jaxrl/reppo.py rename to src/jaxrl/reppo.py index 7e38040..f10c16d 100644 --- a/reppo/jaxrl/reppo.py +++ b/src/jaxrl/reppo.py @@ -17,15 +17,15 @@ from jax.random import PRNGKey from omegaconf import DictConfig, OmegaConf import wandb -from reppo.env_utils.jax_wrappers import ( +from src.env_utils.jax_wrappers import ( BraxGymnaxWrapper, ClipAction, LogWrapper, MjxGymnaxWrapper, NormalizeVec, ) -from reppo.jaxrl import utils -from reppo.network_utils.jax_models import ( +from src.jaxrl import utils +from src.network_utils.jax_models import ( CategoricalCriticNetwork, CriticNetwork, SACActorNetworks, @@ -928,10 +928,8 @@ def run(cfg: DictConfig, trial: optuna.Trial | None) -> float: return (0.1 * sweep_metrics_array.mean() + sweep_metrics_array[:, -1].mean()).item() -@hydra.main(version_base=None, config_path="../../config", config_name="sac") +@hydra.main(version_base=None, config_path="../../config", config_name="reppo") def main(cfg: DictConfig): - cfg.hyperparameters = OmegaConf.merge(cfg.hyperparameters, cfg.experiment_overrides) - run(cfg, trial=None) diff --git a/reppo/jaxrl/utils.py b/src/jaxrl/utils.py similarity index 100% rename from reppo/jaxrl/utils.py rename to src/jaxrl/utils.py diff --git a/reppo/network_utils/fast_td3_nets.py b/src/network_utils/fast_td3_nets.py similarity index 100% rename from reppo/network_utils/fast_td3_nets.py rename to src/network_utils/fast_td3_nets.py diff --git a/reppo/network_utils/jax_models.py b/src/network_utils/jax_models.py similarity index 99% rename from reppo/network_utils/jax_models.py rename to src/network_utils/jax_models.py index 851e079..376fb29 100644 --- a/reppo/network_utils/jax_models.py +++ b/src/network_utils/jax_models.py @@ -6,7 +6,7 @@ import jax import jax.numpy as jnp from flax import nnx -from reppo.jaxrl import utils +from src.jaxrl import utils def torch_he_uniform( diff --git a/reppo/network_utils/torch_models.py b/src/network_utils/torch_models.py similarity index 99% rename from reppo/network_utils/torch_models.py rename to src/network_utils/torch_models.py index 6a74dce..7ebd8e5 100644 --- a/reppo/network_utils/torch_models.py +++ b/src/network_utils/torch_models.py @@ -4,7 +4,7 @@ from torch.distributions import constraints from torch.distributions.transforms import Transform from torch.distributions.normal import Normal -from reppo.torchrl.reppo import hl_gauss +from src.torchrl.reppo import hl_gauss class TanhTransform(Transform): diff --git a/reppo/torchrl/envs.py b/src/torchrl/envs.py similarity index 100% rename from reppo/torchrl/envs.py rename to src/torchrl/envs.py diff --git a/reppo/torchrl/fast_td3.py b/src/torchrl/fast_td3.py similarity index 100% rename from reppo/torchrl/fast_td3.py rename to src/torchrl/fast_td3.py diff --git a/reppo/torchrl/hyperparams.py b/src/torchrl/hyperparams.py similarity index 100% rename from reppo/torchrl/hyperparams.py rename to src/torchrl/hyperparams.py diff --git a/reppo/torchrl/reppo.py b/src/torchrl/reppo.py similarity index 99% rename from reppo/torchrl/reppo.py rename to src/torchrl/reppo.py index 230e032..d0f269d 100644 --- a/reppo/torchrl/reppo.py +++ b/src/torchrl/reppo.py @@ -26,9 +26,9 @@ import torch.optim as optim from torchinfo import summary from tensordict import TensorDict from torch.amp import GradScaler -from reppo.torchrl.envs import make_envs -from reppo.network_utils.torch_models import Actor, Critic -from reppo.torchrl.reppo import ( +from src.torchrl.envs import make_envs +from src.network_utils.torch_models import Actor, Critic +from src.torchrl.reppo import ( EmpiricalNormalization, hl_gauss, ) diff --git a/reppo/torchrl/reppo_util.py b/src/torchrl/reppo_util.py similarity index 100% rename from reppo/torchrl/reppo_util.py rename to src/torchrl/reppo_util.py diff --git a/reppo/torchrl/tensordict_replay_buffer.py b/src/torchrl/tensordict_replay_buffer.py similarity index 100% rename from reppo/torchrl/tensordict_replay_buffer.py rename to src/torchrl/tensordict_replay_buffer.py