diff --git a/src/torchrl/tensordict_replay_buffer.py b/config/trial_spec/default.yaml
similarity index 100%
rename from src/torchrl/tensordict_replay_buffer.py
rename to config/trial_spec/default.yaml
diff --git a/pyproject.toml b/pyproject.toml
index 61223ce..bf9ce7f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,11 +38,12 @@ dependencies = [
 ]
 
 [build-system]
-requires = ["hatchling"]
-build-backend = "hatchling.build"
+requires = ["uv_build>=0.8.0,<0.9.0"]
+build-backend = "uv_build"
 
-[tool.hatch.build.targets.wheel]
-packages = ["onpolicy_sac"]
+[tool.uv.build-backend]
+module-name = "reppo_alg"
+module-root = ""
 
 [tool.ruff]
 # Exclude a variety of commonly ignored directories.
diff --git a/reppo_alg/__init__.py b/reppo_alg/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/reppo_alg/env_utils/__init__.py b/reppo_alg/env_utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/env_utils/jax_wrappers.py b/reppo_alg/env_utils/jax_wrappers.py
similarity index 100%
rename from src/env_utils/jax_wrappers.py
rename to reppo_alg/env_utils/jax_wrappers.py
diff --git a/src/env_utils/torch_wrappers/humanoid_bench_env.py b/reppo_alg/env_utils/torch_wrappers/humanoid_bench_env.py
similarity index 100%
rename from src/env_utils/torch_wrappers/humanoid_bench_env.py
rename to reppo_alg/env_utils/torch_wrappers/humanoid_bench_env.py
diff --git a/src/env_utils/torch_wrappers/isaaclab_env.py b/reppo_alg/env_utils/torch_wrappers/isaaclab_env.py
similarity index 99%
rename from src/env_utils/torch_wrappers/isaaclab_env.py
rename to reppo_alg/env_utils/torch_wrappers/isaaclab_env.py
index 3729d59..2e016b0 100644
--- a/src/env_utils/torch_wrappers/isaaclab_env.py
+++ b/reppo_alg/env_utils/torch_wrappers/isaaclab_env.py
@@ -3,11 +3,11 @@
 from typing import Optional
 
 import gymnasium as gym
 import torch
 from isaaclab.app import AppLauncher
-from isaaclab_tasks.utils.parse_cfg import parse_env_cfg
 
 app_launcher = AppLauncher(headless=True)
 simulation_app = app_launcher.app
+from isaaclab_tasks.utils.parse_cfg import parse_env_cfg
 
 
 class IsaacLabEnv:
diff --git a/src/env_utils/torch_wrappers/maniskill_wrapper.py b/reppo_alg/env_utils/torch_wrappers/maniskill_wrapper.py
similarity index 100%
rename from src/env_utils/torch_wrappers/maniskill_wrapper.py
rename to reppo_alg/env_utils/torch_wrappers/maniskill_wrapper.py
diff --git a/src/env_utils/torch_wrappers/mtbench_env.py b/reppo_alg/env_utils/torch_wrappers/mtbench_env.py
similarity index 100%
rename from src/env_utils/torch_wrappers/mtbench_env.py
rename to reppo_alg/env_utils/torch_wrappers/mtbench_env.py
diff --git a/src/env_utils/torch_wrappers/mujoco_playground_env.py b/reppo_alg/env_utils/torch_wrappers/mujoco_playground_env.py
similarity index 99%
rename from src/env_utils/torch_wrappers/mujoco_playground_env.py
rename to reppo_alg/env_utils/torch_wrappers/mujoco_playground_env.py
index e07fac9..e0cdd40 100644
--- a/src/env_utils/torch_wrappers/mujoco_playground_env.py
+++ b/reppo_alg/env_utils/torch_wrappers/mujoco_playground_env.py
@@ -1,6 +1,5 @@
 import jax
 from mujoco_playground import registry, wrapper_torch
-import torch
 
 jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
 jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
@@ -70,7 +69,7 @@ class RandomizeInitialWrapper(wrapper_torch.RSLRLBraxWrapper):
             self.key, self.env_state.info["steps"].shape, 0, 1000
         ).astype(jax.numpy.float32)
         return obs, critic_obs
-    
+
     def step(self, action):
         obs, reward, done, info = super().step(action)
         return obs, reward, done, done, info
diff --git a/src/jaxrl/__init__.py b/reppo_alg/jaxrl/__init__.py
similarity index 100%
rename from src/jaxrl/__init__.py
rename to reppo_alg/jaxrl/__init__.py
diff --git a/src/jaxrl/normalization.py b/reppo_alg/jaxrl/normalization.py
similarity index 100%
rename from src/jaxrl/normalization.py
rename to reppo_alg/jaxrl/normalization.py
diff --git a/src/jaxrl/ppo_mjx.py b/reppo_alg/jaxrl/ppo_mjx.py
similarity index 99%
rename from src/jaxrl/ppo_mjx.py
rename to reppo_alg/jaxrl/ppo_mjx.py
index bbeb1b9..09eb3a6 100644
--- a/src/jaxrl/ppo_mjx.py
+++ b/reppo_alg/jaxrl/ppo_mjx.py
@@ -18,14 +18,14 @@
 from jax.random import PRNGKey
 from omegaconf import DictConfig, OmegaConf
 
 import wandb
-from src.env_utils.jax_wrappers import (
+from reppo_alg.env_utils.jax_wrappers import (
     BraxGymnaxWrapper,
     ClipAction,
     LogWrapper,
     MjxGymnaxWrapper,
 )
-from src.jaxrl import utils
-from src.jaxrl.normalization import NormalizationState, Normalizer
+from reppo_alg.jaxrl import utils
+from reppo_alg.jaxrl.normalization import NormalizationState, Normalizer
 
 logging.basicConfig(level=logging.INFO)
diff --git a/src/jaxrl/reppo.py b/reppo_alg/jaxrl/reppo.py
similarity index 97%
rename from src/jaxrl/reppo.py
rename to reppo_alg/jaxrl/reppo.py
index f10c16d..ecf14fb 100644
--- a/src/jaxrl/reppo.py
+++ b/reppo_alg/jaxrl/reppo.py
@@ -17,15 +17,15 @@ from jax.random import PRNGKey
 from omegaconf import DictConfig, OmegaConf
 
 import wandb
-from src.env_utils.jax_wrappers import (
+from reppo_alg.env_utils.jax_wrappers import (
     BraxGymnaxWrapper,
     ClipAction,
     LogWrapper,
     MjxGymnaxWrapper,
     NormalizeVec,
 )
-from src.jaxrl import utils
-from src.network_utils.jax_models import (
+from reppo_alg.jaxrl import utils, muon
+from reppo_alg.network_utils.jax_models import (
     CategoricalCriticNetwork,
     CriticNetwork,
     SACActorNetworks,
@@ -33,10 +33,6 @@ from src.network_utils.jax_models import (
 )
 
 logging.basicConfig(level=logging.INFO)
 
-import mujoco
-
-print(mujoco.__file__)
-
 
 class Policy(typing.Protocol):
     def __call__(
@@ -242,14 +238,16 @@ def make_init(
 
     if cfg.max_grad_norm is not None:
         actor_optimizer = optax.chain(
-            optax.clip_by_global_norm(cfg.max_grad_norm), optax.adam(lr)
+            optax.clip_by_global_norm(cfg.max_grad_norm),
+            muon.muon(lr),  # optax.adam(lr)
         )
         critic_optimizer = optax.chain(
-            optax.clip_by_global_norm(cfg.max_grad_norm), optax.adam(lr)
+            optax.clip_by_global_norm(cfg.max_grad_norm),
+            muon.muon(lr),  # optax.adam(lr)
         )
     else:
-        actor_optimizer = optax.adam(lr)
-        critic_optimizer = optax.adam(lr)
+        actor_optimizer = muon.muon(lr)  # optax.adam(lr)
+        critic_optimizer = muon.muon(lr)  # optax.adam(lr)
 
     actor_trainstate = nnx.TrainState.create(
         graphdef=nnx.graphdef(actor_networks),
@@ -282,14 +280,6 @@ def make_init(
         ).astype(jnp.float32)
         env_state.set_env_state(_env_state)
 
-        # mock_action = jnp.zeros(
-        #     (1, 6), dtype=jnp.float32
-        # )
-        # print(mock_action.shape)
-        # print(obs.shape)
-        # print(nnx.tabulate(critic_networks, obs[:1], mock_action))
-        # print(nnx.tabulate(actor_networks, obs[:1]))
-
         return SACTrainState(
             actor=actor_trainstate,
             actor_target=actor_target_trainstate,
@@ -318,7 +308,6 @@ def make_train_fn(
     # env = VecEnv(env, cfg.num_envs)
     if cfg.normalize_env:
        env = NormalizeVec(env)
-    print(env)
     eval_fn = make_eval_fn(env, cfg.max_episode_steps, reward_scale=reward_scale)
     action_size_target = (
         jnp.prod(jnp.array(env.action_space(env_params).shape)) * cfg.ent_target_mult
@@ -452,12 +441,10 @@
             reverse=True,
         )
         # Reshape data to (num_steps * num_envs, ...)
-        jax.debug.print("num trunc {}", batch.truncated.sum(), ordered=True)
         data = (batch, target_values)
         data = jax.tree.map(
             lambda x: x.reshape((cfg.num_steps * cfg.num_envs, *x.shape[2:])), data
         )
-        # jax.debug.print("whole data {}", data[0].truncated.sum(), ordered=True)
 
         train_state = train_state.replace(
             actor_target=train_state.actor_target.replace(
diff --git a/src/jaxrl/utils.py b/reppo_alg/jaxrl/utils.py
similarity index 100%
rename from src/jaxrl/utils.py
rename to reppo_alg/jaxrl/utils.py
diff --git a/reppo_alg/network_utils/__init__.py b/reppo_alg/network_utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/network_utils/fast_td3_nets.py b/reppo_alg/network_utils/fast_td3_nets.py
similarity index 95%
rename from src/network_utils/fast_td3_nets.py
rename to reppo_alg/network_utils/fast_td3_nets.py
index b7d4ff1..1338a0c 100644
--- a/src/network_utils/fast_td3_nets.py
+++ b/reppo_alg/network_utils/fast_td3_nets.py
@@ -52,13 +52,13 @@ class DistributionalQNetwork(nn.Module):
         )
         target_z = target_z.clamp(self.v_min, self.v_max)
         b = (target_z - self.v_min) / delta_z
-        l = torch.floor(b).long()
+        low = torch.floor(b).long()
         u = torch.ceil(b).long()
 
-        l_mask = torch.logical_and((u > 0), (l == u))
-        u_mask = torch.logical_and((l < (self.num_atoms - 1)), (l == u))
+        l_mask = torch.logical_and((u > 0), (low == u))
+        u_mask = torch.logical_and((low < (self.num_atoms - 1)), (low == u))
 
-        l = torch.where(l_mask, l - 1, l)
+        low = torch.where(l_mask, low - 1, low)
         u = torch.where(u_mask, u + 1, u)
 
         next_dist = F.softmax(self.forward(obs, actions), dim=1)
@@ -72,10 +72,10 @@
             .long()
         )
         proj_dist.view(-1).index_add_(
-            0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
+            0, (low + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
         )
         proj_dist.view(-1).index_add_(
-            0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1)
+            0, (u + offset).view(-1), (next_dist * (b - low.float())).view(-1)
         )
 
         return proj_dist
diff --git a/src/network_utils/jax_models.py b/reppo_alg/network_utils/jax_models.py
similarity index 99%
rename from src/network_utils/jax_models.py
rename to reppo_alg/network_utils/jax_models.py
index 376fb29..7897644 100644
--- a/src/network_utils/jax_models.py
+++ b/reppo_alg/network_utils/jax_models.py
@@ -6,7 +6,7 @@
 import jax
 import jax.numpy as jnp
 from flax import nnx
 
-from src.jaxrl import utils
+from reppo_alg.jaxrl import utils
 
 def torch_he_uniform(
diff --git a/src/network_utils/torch_models.py b/reppo_alg/network_utils/torch_models.py
similarity index 98%
rename from src/network_utils/torch_models.py
rename to reppo_alg/network_utils/torch_models.py
index 8457812..37d2664 100644
--- a/src/network_utils/torch_models.py
+++ b/reppo_alg/network_utils/torch_models.py
@@ -4,7 +4,7 @@
 from torch.distributions import constraints
 from torch.distributions.transforms import Transform
 from torch.distributions.normal import Normal
 
-from src.torchrl.reppo_util import hl_gauss
+from reppo_alg.torchrl.reppo_util import hl_gauss
 
 class TanhTransform(Transform):
@@ -34,7 +34,9 @@ class TanhTransform(Transform):
     codomain = constraints.interval(-1.0, 1.0)
     bijective = True
     sign = +1
-    log2 = torch.log(torch.tensor(2.0)).to("cuda" if torch.cuda.is_available() else "cpu")
+    log2 = torch.log(torch.tensor(2.0)).to(
+        "cuda" if torch.cuda.is_available() else "cpu"
+    )
 
     def __eq__(self, other):
         return isinstance(other, TanhTransform)
diff --git a/reppo_alg/torchrl/__init__.py b/reppo_alg/torchrl/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/torchrl/envs.py b/reppo_alg/torchrl/envs.py
similarity index 90%
rename from src/torchrl/envs.py
rename to reppo_alg/torchrl/envs.py
index 14a0592..212b252 100644
--- a/src/torchrl/envs.py
+++ b/reppo_alg/torchrl/envs.py
@@ -4,7 +4,7 @@ from omegaconf import DictConfig
 
 def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
     if cfg.env.type == "humanoid_bench":
-        from src.env_utils.torch_wrappers.humanoid_bench_env import (
+        from reppo_alg.env_utils.torch_wrappers.humanoid_bench_env import (
             HumanoidBenchEnv,
         )
 
@@ -13,7 +13,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
         )
         return envs, envs
     elif cfg.env.type == "isaaclab":
-        from src.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
+        from reppo_alg.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
 
         envs = IsaacLabEnv(
             cfg.env.name,
@@ -24,7 +24,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
         )
         return envs, envs
     elif cfg.env.type == "mjx":
-        from src.env_utils.torch_wrappers.mujoco_playground_env import make_env
+        from reppo_alg.env_utils.torch_wrappers.mujoco_playground_env import make_env
 
         # TODO: Check if re-using same envs for eval could reduce memory usage
         envs, eval_envs = make_env(
@@ -43,7 +43,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
         from mani_skill.utils import gym_utils
         from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper
         from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
-        from src.env_utils.torch_wrappers.maniskill_wrapper import (
+        from reppo_alg.env_utils.torch_wrappers.maniskill_wrapper import (
             ManiSkillWrapper,
         )
 
diff --git a/src/torchrl/fast_td3.py b/reppo_alg/torchrl/fast_td3.py
similarity index 96%
rename from src/torchrl/fast_td3.py
rename to reppo_alg/torchrl/fast_td3.py
index aacdb2a..2528f80 100644
--- a/src/torchrl/fast_td3.py
+++ b/reppo_alg/torchrl/fast_td3.py
@@ -29,7 +29,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from src.torchrl.reppo_util import (
+from reppo_alg.torchrl.reppo_util import (
     EmpiricalNormalization,
     PerTaskRewardNormalizer,
     RewardNormalizer,
@@ -42,10 +42,6 @@
 from torch.amp import GradScaler, autocast
 
 torch.set_float32_matmul_precision("high")
 
-try:
-    import jax.numpy as jnp
-except ImportError:
-    pass
 
 def main():
@@ -90,15 +86,15 @@ def main():
     print(f"Using device: {device}")
 
     if args.env_name.startswith("h1hand-") or args.env_name.startswith("h1-"):
-        from src.env_utils.torch_wrappers.humanoid_bench_env import (
+        from reppo_alg.env_utils.torch_wrappers.humanoid_bench_env import (
             HumanoidBenchEnv,
         )
 
         env_type = "humanoid_bench"
         envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)
         eval_envs = envs
     elif args.env_name.startswith("Isaac-"):
-        from src.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
+        from reppo_alg.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
 
         env_type = "isaaclab"
         envs = IsaacLabEnv(
@@ -110,14 +106,14 @@ def main():
         )
         eval_envs = envs
     elif args.env_name.startswith("MTBench-"):
-        from src.env_utils.torch_wrappers.mtbench_env import MTBenchEnv
+        from reppo_alg.env_utils.torch_wrappers.mtbench_env import MTBenchEnv
 
         env_name = "-".join(args.env_name.split("-")[1:])
         env_type = "mtbench"
         envs = MTBenchEnv(env_name, args.device_rank, args.num_envs, args.seed)
         eval_envs = envs
     else:
-        from src.env_utils.torch_wrappers.mujoco_playground_env import make_env
+        from reppo_alg.env_utils.torch_wrappers.mujoco_playground_env import make_env
 
         # TODO: Check if re-using same envs for eval could reduce memory usage
         env_type = "mujoco_playground"
@@ -133,11 +129,11 @@ def main():
     )
 
     n_act = envs.num_actions
-    n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0]
+    n_obs = envs.num_obs if isinstance(envs.num_obs, int) else envs.num_obs[0]
     if envs.asymmetric_obs:
         n_critic_obs = (
             envs.num_privileged_obs
-            if type(envs.num_privileged_obs) == int
+            if isinstance(envs.num_privileged_obs, int)
             else envs.num_privileged_obs[0]
         )
     else:
@@ -198,7 +194,7 @@ def main():
     if args.agent == "fasttd3":
         if env_type in ["mtbench"]:
-            from src.network_utils.fast_td3_nets import (
+            from reppo_alg.network_utils.fast_td3_nets import (
                 MultiTaskActor,
                 MultiTaskCritic,
             )
 
@@ -206,7 +202,7 @@ def main():
             actor_cls = MultiTaskActor
             critic_cls = MultiTaskCritic
         else:
-            from src.network_utils.fast_td3_nets import Actor, Critic
+            from reppo_alg.network_utils.fast_td3_nets import Actor, Critic
 
             actor_cls = Actor
             critic_cls = Critic
@@ -214,7 +210,7 @@ def main():
         print("Using FastTD3")
     elif args.agent == "fasttd3_simbav2":
         if env_type in ["mtbench"]:
-            from src.network_utils.fast_td3_nets_simbav2 import (
+            from reppo_alg.network_utils.fast_td3_nets_simbav2 import (
                 MultiTaskActor,
                 MultiTaskCritic,
             )
@@ -222,7 +218,7 @@ def main():
             actor_cls = MultiTaskActor
             critic_cls = MultiTaskCritic
         else:
-            from src.network_utils.fast_td3_nets_simbav2 import Actor, Critic
+            from reppo_alg.network_utils.fast_td3_nets_simbav2 import Actor, Critic
 
             actor_cls = Actor
             critic_cls = Critic
diff --git a/src/torchrl/hyperparams.py b/reppo_alg/torchrl/hyperparams.py
similarity index 98%
rename from src/torchrl/hyperparams.py
rename to reppo_alg/torchrl/hyperparams.py
index fe6733b..bfa96f1 100644
--- a/src/torchrl/hyperparams.py
+++ b/reppo_alg/torchrl/hyperparams.py
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 import tyro
 
 
-@dataclass 
+@dataclass
 class BaseArgs:
     # Default hyperparameters -- specifically for HumanoidBench
     # See MuJoCoPlaygroundArgs for default hyperparameters for MuJoCo Playground
@@ -153,7 +153,6 @@ def get_args():
         "h1hand-basketball-v0": H1HandBasketballArgs,
         "h1hand-window-v0": H1HandWindowArgs,
         "h1hand-package-v0": H1HandPackageArgs,
-        "h1hand-truck-v0": H1HandTruckArgs,
         # MuJoCo Playground
         # NOTE: These tasks are not full list of MuJoCo Playground tasks
         "G1JoystickFlatTerrain": G1JoystickFlatTerrainArgs,
@@ -275,13 +274,6 @@
     v_max: float = 10000.0
 
 
-@dataclass
-class H1HandTruckArgs(HumanoidBenchArgs):
-    env_name: str = "h1hand-truck-v0"
-    v_min: float = -1000.0
-    v_max: float = 1000.0
-
-
 @dataclass
 class MuJoCoPlaygroundArgs(BaseArgs):
     # Default hyperparameters for many of Playground environments
@@ -292,6 +284,7 @@ class MuJoCoPlaygroundArgs(BaseArgs):
     num_eval_envs: int = 1024
     gamma: float = 0.99
 
+
 @dataclass
 class MTBenchArgs(BaseArgs):
     # Default hyperparameters for MTBench
diff --git a/src/torchrl/reppo.py b/reppo_alg/torchrl/reppo.py
similarity index 98%
rename from src/torchrl/reppo.py
rename to reppo_alg/torchrl/reppo.py
index ddeae49..6015e57 100644
--- a/src/torchrl/reppo.py
+++ b/reppo_alg/torchrl/reppo.py
@@ -12,7 +12,7 @@ from omegaconf import DictConfig, OmegaConf
 
 import wandb
 
-from src.torchrl.reppo_util import EmpiricalNormalization, hl_gauss
+from reppo_alg.torchrl.reppo_util import EmpiricalNormalization, hl_gauss
 
 try:
     # Required for avoiding IsaacGym import error
@@ -28,13 +28,8 @@
 import torch.optim as optim
 from torchinfo import summary
 from tensordict import TensorDict
 from torch.amp import GradScaler
-from src.torchrl.envs import make_envs
-from src.network_utils.torch_models import Actor, Critic
-
-try:
-    import jax.numpy as jnp
-except ImportError:
-    pass
+from reppo_alg.torchrl.envs import make_envs
+from reppo_alg.network_utils.torch_models import Actor, Critic
 
 torch.set_float32_matmul_precision("medium")
@@ -484,11 +479,11 @@ def main(cfg):
     envs, eval_envs = make_envs(cfg=cfg, device=device, seed=cfg.seed)
 
     n_act = envs.num_actions
-    n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0]
+    n_obs = envs.num_obs if isinstance(envs.num_obs, int) else envs.num_obs[0]
     if envs.asymmetric_obs:
         n_critic_obs = (
             envs.num_privileged_obs
-            if type(envs.num_privileged_obs) == int
+            if isinstance(envs.num_privileged_obs, int)
             else envs.num_privileged_obs[0]
         )
     else:
diff --git a/src/torchrl/reppo_util.py b/reppo_alg/torchrl/reppo_util.py
similarity index 100%
rename from src/torchrl/reppo_util.py
rename to reppo_alg/torchrl/reppo_util.py
diff --git a/reppo_alg/torchrl/tensordict_replay_buffer.py b/reppo_alg/torchrl/tensordict_replay_buffer.py
new file mode 100644
index 0000000..e69de29