Fixes build errors due to name conflicts

cvoelcker 2025-07-21 18:17:03 -04:00
parent 094ee0c5ba
commit e2f99648ae
26 changed files with 52 additions and 79 deletions


@@ -38,11 +38,12 @@ dependencies = [
 ]

 [build-system]
-requires = ["hatchling"]
-build-backend = "hatchling.build"
+requires = ["uv_build>=0.8.0,<0.9.0"]
+build-backend = "uv_build"

-[tool.hatch.build.targets.wheel]
-packages = ["onpolicy_sac"]
+[tool.uv.build-backend]
+module-name = "reppo_alg"
+module-root = ""

 [tool.ruff]
 # Exclude a variety of commonly ignored directories.
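Note: the backend switch pairs with the package rename. uv's native build backend defaults `module-root` to `"src"`, so the empty `module-root` points it at the `reppo_alg/` directory in the repository root (this reading of the settings is an interpretation, not part of the diff):

    [build-system]
    requires = ["uv_build>=0.8.0,<0.9.0"]
    build-backend = "uv_build"

    [tool.uv.build-backend]
    # Package the reppo_alg/ directory at the repository root;
    # module-root = "" overrides uv_build's default of "src".
    module-name = "reppo_alg"
    module-root = ""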

reppo_alg/__init__.py (new, empty file)

@@ -3,11 +3,11 @@ from typing import Optional
 import gymnasium as gym
 import torch
 from isaaclab.app import AppLauncher
-from isaaclab_tasks.utils.parse_cfg import parse_env_cfg

 app_launcher = AppLauncher(headless=True)
 simulation_app = app_launcher.app

+from isaaclab_tasks.utils.parse_cfg import parse_env_cfg

 class IsaacLabEnv:
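Note: this reorder is load-bearing, not cosmetic. Isaac Lab task modules can only be imported once the Omniverse app is running, so the `isaaclab_tasks` import has to follow `AppLauncher`. A minimal sketch of the required pattern (standard Isaac Lab usage, assumed rather than taken from this repo):

    from isaaclab.app import AppLauncher

    # Launch the simulation app first; isaaclab_tasks imports fail before this.
    app_launcher = AppLauncher(headless=True)
    simulation_app = app_launcher.app

    from isaaclab_tasks.utils.parse_cfg import parse_env_cfg  # noqa: E402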


@@ -1,6 +1,5 @@
 import jax
 from mujoco_playground import registry, wrapper_torch
-import torch

 jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
 jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
@@ -70,7 +69,7 @@ class RandomizeInitialWrapper(wrapper_torch.RSLRLBraxWrapper):
             self.key, self.env_state.info["steps"].shape, 0, 1000
         ).astype(jax.numpy.float32)
         return obs, critic_obs

     def step(self, action):
         obs, reward, done, info = super().step(action)
         return obs, reward, done, done, info
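Note: returning `done` twice adapts the wrapped 4-tuple step to Gymnasium's 5-tuple `(obs, reward, terminated, truncated, info)`; the wrapper does not distinguish termination from truncation. A sketch of the consuming side under that assumption (hypothetical loop, not this repo's code):

    # Rollout loop consuming the 5-tuple returned by the wrapper above.
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated | truncated  # identical tensors here, so done == terminated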


@@ -18,14 +18,14 @@ from jax.random import PRNGKey
 from omegaconf import DictConfig, OmegaConf

 import wandb
-from src.env_utils.jax_wrappers import (
+from reppo_alg.env_utils.jax_wrappers import (
     BraxGymnaxWrapper,
     ClipAction,
     LogWrapper,
     MjxGymnaxWrapper,
 )
-from src.jaxrl import utils
-from src.jaxrl.normalization import NormalizationState, Normalizer
+from reppo_alg.jaxrl import utils
+from reppo_alg.jaxrl.normalization import NormalizationState, Normalizer

 logging.basicConfig(level=logging.INFO)


@@ -17,15 +17,15 @@ from jax.random import PRNGKey
 from omegaconf import DictConfig, OmegaConf

 import wandb
-from src.env_utils.jax_wrappers import (
+from reppo_alg.env_utils.jax_wrappers import (
     BraxGymnaxWrapper,
     ClipAction,
     LogWrapper,
     MjxGymnaxWrapper,
     NormalizeVec,
 )
-from src.jaxrl import utils
-from src.network_utils.jax_models import (
+from reppo_alg.jaxrl import utils, muon
+from reppo_alg.network_utils.jax_models import (
     CategoricalCriticNetwork,
     CriticNetwork,
     SACActorNetworks,
@@ -33,10 +33,6 @@ from src.network_utils.jax_models import (
 logging.basicConfig(level=logging.INFO)

-import mujoco
-print(mujoco.__file__)

 class Policy(typing.Protocol):
     def __call__(
@@ -242,14 +238,16 @@ def make_init(
     if cfg.max_grad_norm is not None:
         actor_optimizer = optax.chain(
-            optax.clip_by_global_norm(cfg.max_grad_norm), optax.adam(lr)
+            optax.clip_by_global_norm(cfg.max_grad_norm),
+            muon.muon(lr),  # optax.adam(lr)
         )
         critic_optimizer = optax.chain(
-            optax.clip_by_global_norm(cfg.max_grad_norm), optax.adam(lr)
+            optax.clip_by_global_norm(cfg.max_grad_norm),
+            muon.muon(lr),  # optax.adam(lr)
         )
     else:
-        actor_optimizer = optax.adam(lr)
-        critic_optimizer = optax.adam(lr)
+        actor_optimizer = muon.muon(lr)  # optax.adam(lr)
+        critic_optimizer = muon.muon(lr)  # optax.adam(lr)

     actor_trainstate = nnx.TrainState.create(
         graphdef=nnx.graphdef(actor_networks),
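Note: `muon` is assumed to be the `reppo_alg.jaxrl.muon` module imported above, exposing an optax-compatible gradient-transformation factory; the `# optax.adam(lr)` comments keep the old optimizer visible for rollback. A hedged sketch of the shared pattern:

    import optax

    def make_optimizer(lr, max_grad_norm=None, opt_fn=optax.adam):
        # opt_fn stands in for muon.muon or optax.adam; both return an
        # optax GradientTransformation, so they compose with clipping identically.
        if max_grad_norm is not None:
            return optax.chain(optax.clip_by_global_norm(max_grad_norm), opt_fn(lr))
        return opt_fn(lr)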
@@ -282,14 +280,6 @@ def make_init(
         ).astype(jnp.float32)
         env_state.set_env_state(_env_state)

-        # mock_action = jnp.zeros(
-        #     (1, 6), dtype=jnp.float32
-        # )
-        # print(mock_action.shape)
-        # print(obs.shape)
-        # print(nnx.tabulate(critic_networks, obs[:1], mock_action))
-        # print(nnx.tabulate(actor_networks, obs[:1]))

         return SACTrainState(
             actor=actor_trainstate,
             actor_target=actor_target_trainstate,
@@ -318,7 +308,6 @@ def make_train_fn(
     # env = VecEnv(env, cfg.num_envs)
     if cfg.normalize_env:
         env = NormalizeVec(env)
-    print(env)
     eval_fn = make_eval_fn(env, cfg.max_episode_steps, reward_scale=reward_scale)
     action_size_target = (
         jnp.prod(jnp.array(env.action_space(env_params).shape)) * cfg.ent_target_mult
@@ -452,12 +441,10 @@ def make_train_fn(
             reverse=True,
         )
         # Reshape data to (num_steps * num_envs, ...)
-        jax.debug.print("num trunc {}", batch.truncated.sum(), ordered=True)
         data = (batch, target_values)
         data = jax.tree.map(
             lambda x: x.reshape((cfg.num_steps * cfg.num_envs, *x.shape[2:])), data
         )
-        # jax.debug.print("whole data {}", data[0].truncated.sum(), ordered=True)
         train_state = train_state.replace(
             actor_target=train_state.actor_target.replace(


@@ -52,13 +52,13 @@ class DistributionalQNetwork(nn.Module):
         )
         target_z = target_z.clamp(self.v_min, self.v_max)
         b = (target_z - self.v_min) / delta_z
-        l = torch.floor(b).long()
+        low = torch.floor(b).long()
         u = torch.ceil(b).long()
-        l_mask = torch.logical_and((u > 0), (l == u))
-        u_mask = torch.logical_and((l < (self.num_atoms - 1)), (l == u))
-        l = torch.where(l_mask, l - 1, l)
+        l_mask = torch.logical_and((u > 0), (low == u))
+        u_mask = torch.logical_and((low < (self.num_atoms - 1)), (low == u))
+        low = torch.where(l_mask, low - 1, low)
         u = torch.where(u_mask, u + 1, u)

         next_dist = F.softmax(self.forward(obs, actions), dim=1)
@@ -72,10 +72,10 @@ class DistributionalQNetwork(nn.Module):
             .long()
         )
         proj_dist.view(-1).index_add_(
-            0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
+            0, (low + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
         )
         proj_dist.view(-1).index_add_(
-            0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1)
+            0, (u + offset).view(-1), (next_dist * (b - low.float())).view(-1)
         )
         return proj_dist
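Note: the `l` to `low` rename clears Ruff's E741 ambiguous-variable-name error without changing behavior; the surrounding code is the standard C51 categorical projection. A self-contained sketch of that projection under the same atom layout (illustrative, not this repo's API):

    import torch

    def project_categorical(next_dist, target_z, v_min, v_max, num_atoms):
        # next_dist: (batch, num_atoms) probabilities; target_z: Bellman-updated atoms.
        delta_z = (v_max - v_min) / (num_atoms - 1)
        b = (target_z.clamp(v_min, v_max) - v_min) / delta_z
        low, up = torch.floor(b).long(), torch.ceil(b).long()
        # When b lands exactly on an atom (low == up), shift one bound so the
        # (up - b) / (b - low) weights do not silently drop that probability mass.
        l_mask = (up > 0) & (low == up)
        u_mask = (low < num_atoms - 1) & (low == up)
        low = torch.where(l_mask, low - 1, low)
        up = torch.where(u_mask, up + 1, up)

        offset = (torch.arange(next_dist.size(0), device=next_dist.device)
                  * num_atoms).unsqueeze(1)
        proj = torch.zeros_like(next_dist)
        proj.view(-1).index_add_(
            0, (low + offset).view(-1), (next_dist * (up.float() - b)).view(-1)
        )
        proj.view(-1).index_add_(
            0, (up + offset).view(-1), (next_dist * (b - low.float())).view(-1)
        )
        return proj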


@@ -6,7 +6,7 @@ import jax
 import jax.numpy as jnp
 from flax import nnx

-from src.jaxrl import utils
+from reppo_alg.jaxrl import utils


 def torch_he_uniform(


@@ -4,7 +4,7 @@ from torch.distributions import constraints
 from torch.distributions.transforms import Transform
 from torch.distributions.normal import Normal

-from src.torchrl.reppo_util import hl_gauss
+from reppo_alg.torchrl.reppo import hl_gauss


 class TanhTransform(Transform):
@@ -34,7 +34,9 @@ class TanhTransform(Transform):
     codomain = constraints.interval(-1.0, 1.0)
     bijective = True
     sign = +1
-    log2 = torch.log(torch.tensor(2.0)).to("cuda" if torch.cuda.is_available() else "cpu")
+    log2 = torch.log(torch.tensor(2.0)).to(
+        "cuda" if torch.cuda.is_available() else "cpu"
+    )

     def __eq__(self, other):
         return isinstance(other, TanhTransform)
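Note: a `log2` constant like this one typically feeds the numerically stable log-det-Jacobian of tanh, log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x)). Assuming this `TanhTransform` follows that standard pattern, a sketch of the identity:

    import torch
    import torch.nn.functional as F

    def tanh_log_abs_det_jacobian(x: torch.Tensor) -> torch.Tensor:
        # 1 - tanh(x)^2 = 4 / (e^x + e^-x)^2, so taking logs gives the stable
        # form 2 * (log 2 - x - softplus(-2x)) with no cancellation at large |x|.
        log2 = torch.log(torch.tensor(2.0, device=x.device))
        return 2.0 * (log2 - x - F.softplus(-2.0 * x))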


@@ -4,7 +4,7 @@ from omegaconf import DictConfig

 def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
     if cfg.env.type == "humanoid_bench":
-        from src.env_utils.torch_wrappers.humanoid_bench_env import (
+        from reppo_alg.env_utils.torch_wrappers.humanoid_bench_env import (
             HumanoidBenchEnv,
         )
@@ -13,7 +13,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
         )
         return envs, envs
     elif cfg.env.type == "isaaclab":
-        from src.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
+        from reppo_alg.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv

         envs = IsaacLabEnv(
             cfg.env.name,
@@ -24,7 +24,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
         )
         return envs, envs
     elif cfg.env.type == "mjx":
-        from src.env_utils.torch_wrappers.mujoco_playground_env import make_env
+        from reppo_alg.env_utils.torch_wrappers.mujoco_playground_env import make_env

         # TODO: Check if re-using same envs for eval could reduce memory usage
         envs, eval_envs = make_env(
@@ -43,7 +43,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
         from mani_skill.utils import gym_utils
         from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper
         from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
-        from src.env_utils.torch_wrappers.maniskill_wrapper import (
+        from reppo_alg.env_utils.torch_wrappers.maniskill_wrapper import (
             ManiSkillWrapper,
         )


@@ -29,7 +29,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from src.torchrl.reppo_util import (
+from reppo_alg.torchrl.reppo import (
     EmpiricalNormalization,
     PerTaskRewardNormalizer,
     RewardNormalizer,
@@ -42,10 +42,6 @@ from torch.amp import GradScaler, autocast
 torch.set_float32_matmul_precision("high")

-try:
-    import jax.numpy as jnp
-except ImportError:
-    pass

 def main():
@@ -90,7 +86,7 @@ def main():
     print(f"Using device: {device}")

     if args.env_name.startswith("h1hand-") or args.env_name.startswith("h1-"):
-        from src.env_utils.torch_wrappers.humanoid_bench_env import (
+        from reppo_alg.env_utils.torch_wrappers.humanoid_bench_env import (
             HumanoidBenchEnv,
         )
@@ -98,7 +94,7 @@ def main():
         envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)
         eval_envs = envs
     elif args.env_name.startswith("Isaac-"):
-        from src.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
+        from reppo_alg.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv

         env_type = "isaaclab"
         envs = IsaacLabEnv(
@@ -110,14 +106,14 @@ def main():
         )
         eval_envs = envs
     elif args.env_name.startswith("MTBench-"):
-        from src.env_utils.torch_wrappers.mtbench_env import MTBenchEnv
+        from reppo_alg.env_utils.torch_wrappers.mtbench_env import MTBenchEnv

         env_name = "-".join(args.env_name.split("-")[1:])
         env_type = "mtbench"
         envs = MTBenchEnv(env_name, args.device_rank, args.num_envs, args.seed)
         eval_envs = envs
     else:
-        from src.env_utils.torch_wrappers.mujoco_playground_env import make_env
+        from reppo_alg.env_utils.torch_wrappers.mujoco_playground_env import make_env

         # TODO: Check if re-using same envs for eval could reduce memory usage
         env_type = "mujoco_playground"
@@ -133,11 +129,11 @@ def main():
     )
     n_act = envs.num_actions
-    n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0]
+    n_obs = envs.num_obs if isinstance(envs.num_obs, int) else envs.num_obs[0]
     if envs.asymmetric_obs:
         n_critic_obs = (
             envs.num_privileged_obs
-            if type(envs.num_privileged_obs) == int
+            if isinstance(envs.num_privileged_obs, int)
             else envs.num_privileged_obs[0]
         )
     else:
@@ -198,7 +194,7 @@ def main():
     if args.agent == "fasttd3":
         if env_type in ["mtbench"]:
-            from src.network_utils.fast_td3_nets import (
+            from reppo_alg.network_utils.fast_td3_nets import (
                 MultiTaskActor,
                 MultiTaskCritic,
             )
@@ -206,7 +202,7 @@ def main():
             actor_cls = MultiTaskActor
             critic_cls = MultiTaskCritic
         else:
-            from src.network_utils.fast_td3_nets import Actor, Critic
+            from reppo_alg.network_utils.fast_td3_nets import Actor, Critic

             actor_cls = Actor
             critic_cls = Critic
@@ -214,7 +210,7 @@ def main():
         print("Using FastTD3")
     elif args.agent == "fasttd3_simbav2":
         if env_type in ["mtbench"]:
-            from src.network_utils.fast_td3_nets_simbav2 import (
+            from reppo_alg.network_utils.fast_td3_nets_simbav2 import (
                 MultiTaskActor,
                 MultiTaskCritic,
             )
@@ -222,7 +218,7 @@ def main():
             actor_cls = MultiTaskActor
             critic_cls = MultiTaskCritic
         else:
-            from src.network_utils.fast_td3_nets_simbav2 import Actor, Critic
+            from reppo_alg.network_utils.fast_td3_nets_simbav2 import Actor, Critic

             actor_cls = Actor
             critic_cls = Critic
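Note: the `type(x) == int` comparisons become `isinstance(x, int)`, which clears Ruff's E721 type-comparison lint and also accepts subclasses of `int`. A tiny illustration with a hypothetical subclass:

    class EnvCount(int):  # hypothetical counter type that subclasses int
        pass

    n = EnvCount(7)
    assert isinstance(n, int)    # passes: subclasses count as int
    assert not (type(n) == int)  # exact-type check rejects the subclass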


@@ -4,7 +4,7 @@ from dataclasses import dataclass
 import tyro


 @dataclass
 class BaseArgs:
     # Default hyperparameters -- specifically for HumanoidBench
     # See MuJoCoPlaygroundArgs for default hyperparameters for MuJoCo Playground
@@ -153,7 +153,6 @@ def get_args():
         "h1hand-basketball-v0": H1HandBasketballArgs,
         "h1hand-window-v0": H1HandWindowArgs,
         "h1hand-package-v0": H1HandPackageArgs,
-        "h1hand-truck-v0": H1HandTruckArgs,
         # MuJoCo Playground
         # NOTE: These tasks are not full list of MuJoCo Playground tasks
         "G1JoystickFlatTerrain": G1JoystickFlatTerrainArgs,
@@ -275,13 +274,6 @@ class H1HandPackageArgs(HumanoidBenchArgs):
     v_max: float = 10000.0

-@dataclass
-class H1HandTruckArgs(HumanoidBenchArgs):
-    env_name: str = "h1hand-truck-v0"
-    v_min: float = -1000.0
-    v_max: float = 1000.0

 @dataclass
 class MuJoCoPlaygroundArgs(BaseArgs):
     # Default hyperparameters for many of Playground environments
@@ -292,6 +284,7 @@ class MuJoCoPlaygroundArgs(BaseArgs):
     num_eval_envs: int = 1024
     gamma: float = 0.99

+
 @dataclass
 class MTBenchArgs(BaseArgs):
     # Default hyperparameters for MTBench


@@ -12,7 +12,7 @@ from omegaconf import DictConfig, OmegaConf

 import wandb
-from src.torchrl.reppo_util import EmpiricalNormalization, hl_gauss
+from reppo_alg.torchrl.reppo import EmpiricalNormalization, hl_gauss

 try:
     # Required for avoiding IsaacGym import error
@@ -28,13 +28,8 @@ import torch.optim as optim
 from torchinfo import summary
 from tensordict import TensorDict
 from torch.amp import GradScaler
-from src.torchrl.envs import make_envs
-from src.network_utils.torch_models import Actor, Critic
-
-try:
-    import jax.numpy as jnp
-except ImportError:
-    pass
+from reppo_alg.torchrl.envs import make_envs
+from reppo_alg.network_utils.torch_models import Actor, Critic

 torch.set_float32_matmul_precision("medium")
@@ -484,11 +479,11 @@ def main(cfg):
     envs, eval_envs = make_envs(cfg=cfg, device=device, seed=cfg.seed)
     n_act = envs.num_actions
-    n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0]
+    n_obs = envs.num_obs if isinstance(envs.num_obs, int) else envs.num_obs[0]
     if envs.asymmetric_obs:
         n_critic_obs = (
             envs.num_privileged_obs
-            if type(envs.num_privileged_obs) == int
+            if isinstance(envs.num_privileged_obs, int)
             else envs.num_privileged_obs[0]
         )
     else: