Fixes build errors due to name conflicts

cvoelcker 2025-07-21 18:17:03 -04:00
parent 094ee0c5ba
commit e2f99648ae
26 changed files with 52 additions and 79 deletions

View File

@@ -38,11 +38,12 @@ dependencies = [
 ]
 
 [build-system]
-requires = ["hatchling"]
-build-backend = "hatchling.build"
+requires = ["uv_build>=0.8.0,<0.9.0"]
+build-backend = "uv_build"
 
-[tool.hatch.build.targets.wheel]
-packages = ["onpolicy_sac"]
+[tool.uv.build-backend]
+module-name = "reppo_alg"
+module-root = ""
 
 [tool.ruff]
 # Exclude a variety of commonly ignored directories.
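For context, the new backend section assembles to the TOML below (a sketch of just this section; the rest of pyproject.toml is unchanged). The `module-root = ""` line is the detail that matters: uv_build defaults to a `src/` layout, and the empty value points it at the repository root, where the `reppo_alg` package now lives.

    [build-system]
    requires = ["uv_build>=0.8.0,<0.9.0"]
    build-backend = "uv_build"

    [tool.uv.build-backend]
    # module-root defaults to "src"; an empty value makes uv_build look for
    # the reppo_alg/ package directly in the repository root.
    module-name = "reppo_alg"
    module-root = ""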

reppo_alg/__init__.py Normal file (0 lines)
View File

View File

@@ -3,11 +3,11 @@ from typing import Optional
 import gymnasium as gym
 import torch
 from isaaclab.app import AppLauncher
-from isaaclab_tasks.utils.parse_cfg import parse_env_cfg
 
 app_launcher = AppLauncher(headless=True)
 simulation_app = app_launcher.app
 
+from isaaclab_tasks.utils.parse_cfg import parse_env_cfg
 
 
 class IsaacLabEnv:
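The parse_cfg import moves below the AppLauncher call because isaaclab_tasks pulls in Omniverse extension modules that generally only become importable once the simulation app is running. A minimal sketch of the required ordering (assuming headless use, as in the hunk above):

    from isaaclab.app import AppLauncher

    # Launch the Omniverse app first; extension modules are registered here.
    app_launcher = AppLauncher(headless=True)
    simulation_app = app_launcher.app

    # Only after launch is it safe to import from isaaclab_tasks.
    from isaaclab_tasks.utils.parse_cfg import parse_env_cfg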

View File

@@ -1,6 +1,5 @@
 import jax
 from mujoco_playground import registry, wrapper_torch
-import torch
 
 jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
 jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)

View File

@@ -18,14 +18,14 @@ from jax.random import PRNGKey
 from omegaconf import DictConfig, OmegaConf
 
 import wandb
 
-from src.env_utils.jax_wrappers import (
+from reppo_alg.env_utils.jax_wrappers import (
     BraxGymnaxWrapper,
     ClipAction,
     LogWrapper,
     MjxGymnaxWrapper,
 )
-from src.jaxrl import utils
-from src.jaxrl.normalization import NormalizationState, Normalizer
+from reppo_alg.jaxrl import utils
+from reppo_alg.jaxrl.normalization import NormalizationState, Normalizer
 
 logging.basicConfig(level=logging.INFO)

View File

@@ -17,15 +17,15 @@ from jax.random import PRNGKey
 from omegaconf import DictConfig, OmegaConf
 
 import wandb
 
-from src.env_utils.jax_wrappers import (
+from reppo_alg.env_utils.jax_wrappers import (
     BraxGymnaxWrapper,
     ClipAction,
     LogWrapper,
     MjxGymnaxWrapper,
     NormalizeVec,
 )
-from src.jaxrl import utils
-from src.network_utils.jax_models import (
+from reppo_alg.jaxrl import utils, muon
+from reppo_alg.network_utils.jax_models import (
     CategoricalCriticNetwork,
     CriticNetwork,
     SACActorNetworks,
@@ -33,10 +33,6 @@ from src.network_utils.jax_models import (
 )
 
 logging.basicConfig(level=logging.INFO)
-
-import mujoco
-
-print(mujoco.__file__)
 
 
 class Policy(typing.Protocol):
     def __call__(
@@ -242,14 +238,16 @@ def make_init(
     if cfg.max_grad_norm is not None:
         actor_optimizer = optax.chain(
-            optax.clip_by_global_norm(cfg.max_grad_norm), optax.adam(lr)
+            optax.clip_by_global_norm(cfg.max_grad_norm),
+            muon.muon(lr),  # was: optax.adam(lr)
         )
         critic_optimizer = optax.chain(
-            optax.clip_by_global_norm(cfg.max_grad_norm), optax.adam(lr)
+            optax.clip_by_global_norm(cfg.max_grad_norm),
+            muon.muon(lr),  # was: optax.adam(lr)
         )
     else:
-        actor_optimizer = optax.adam(lr)
-        critic_optimizer = optax.adam(lr)
+        actor_optimizer = muon.muon(lr)  # was: optax.adam(lr)
+        critic_optimizer = muon.muon(lr)  # was: optax.adam(lr)
 
     actor_trainstate = nnx.TrainState.create(
         graphdef=nnx.graphdef(actor_networks),
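The swap from optax.adam to the repo's muon optimizer keeps the same gradient-transformation chain shape. A minimal sketch of the pattern, assuming reppo_alg.jaxrl.muon exposes muon(lr) as an optax-style GradientTransformation (the clipping and chaining calls are standard optax; the lr and norm values here are placeholders):

    import optax
    from reppo_alg.jaxrl import muon  # repo-local module referenced above

    lr, max_grad_norm = 3e-4, 0.5
    optimizer = optax.chain(
        # Rescale gradients whose global norm exceeds max_grad_norm...
        optax.clip_by_global_norm(max_grad_norm),
        # ...then apply the update rule; muon.muon(lr) replaces optax.adam(lr).
        muon.muon(lr),
    )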
@@ -282,14 +280,6 @@ def make_init(
     ).astype(jnp.float32)
     env_state.set_env_state(_env_state)
-    # mock_action = jnp.zeros(
-    #     (1, 6), dtype=jnp.float32
-    # )
-    # print(mock_action.shape)
-    # print(obs.shape)
-    # print(nnx.tabulate(critic_networks, obs[:1], mock_action))
-    # print(nnx.tabulate(actor_networks, obs[:1]))
-
     return SACTrainState(
         actor=actor_trainstate,
         actor_target=actor_target_trainstate,
@@ -318,7 +308,6 @@ def make_train_fn(
     # env = VecEnv(env, cfg.num_envs)
     if cfg.normalize_env:
         env = NormalizeVec(env)
-    print(env)
     eval_fn = make_eval_fn(env, cfg.max_episode_steps, reward_scale=reward_scale)
     action_size_target = (
         jnp.prod(jnp.array(env.action_space(env_params).shape)) * cfg.ent_target_mult
@@ -452,12 +441,10 @@ def make_train_fn(
             reverse=True,
         )
         # Reshape data to (num_steps * num_envs, ...)
-        jax.debug.print("num trunc {}", batch.truncated.sum(), ordered=True)
         data = (batch, target_values)
         data = jax.tree.map(
             lambda x: x.reshape((cfg.num_steps * cfg.num_envs, *x.shape[2:])), data
         )
-        # jax.debug.print("whole data {}", data[0].truncated.sum(), ordered=True)
         train_state = train_state.replace(
             actor_target=train_state.actor_target.replace(
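The deleted diagnostics used jax.debug.print rather than print because they sit inside jit-compiled code, where a plain print only fires at trace time; ordered=True additionally preserves print ordering across executions. A standalone sketch of the idiom:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def count_truncated(truncated):
        # Fires on every call, not just during the first trace.
        jax.debug.print("num trunc {}", truncated.sum(), ordered=True)
        return truncated

    count_truncated(jnp.array([True, False, True]))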

View File

@@ -52,13 +52,13 @@ class DistributionalQNetwork(nn.Module):
         )
         target_z = target_z.clamp(self.v_min, self.v_max)
         b = (target_z - self.v_min) / delta_z
-        l = torch.floor(b).long()
+        low = torch.floor(b).long()
         u = torch.ceil(b).long()
-        l_mask = torch.logical_and((u > 0), (l == u))
-        u_mask = torch.logical_and((l < (self.num_atoms - 1)), (l == u))
+        l_mask = torch.logical_and((u > 0), (low == u))
+        u_mask = torch.logical_and((low < (self.num_atoms - 1)), (low == u))
-        l = torch.where(l_mask, l - 1, l)
+        low = torch.where(l_mask, low - 1, low)
         u = torch.where(u_mask, u + 1, u)
 
         next_dist = F.softmax(self.forward(obs, actions), dim=1)
@@ -72,10 +72,10 @@ class DistributionalQNetwork(nn.Module):
             .long()
         )
         proj_dist.view(-1).index_add_(
-            0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
+            0, (low + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
         )
         proj_dist.view(-1).index_add_(
-            0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1)
+            0, (u + offset).view(-1), (next_dist * (b - low.float())).view(-1)
         )
         return proj_dist
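The l → low rename (presumably to avoid the ambiguous single-letter name that ruff flags as E741) touches the categorical C51-style projection, where a Bellman-updated atom that lands between two bins splits its probability mass between the atoms at indices low and u. A toy, self-contained illustration with made-up values:

    import torch

    v_min, v_max, num_atoms = -10.0, 10.0, 5
    delta_z = (v_max - v_min) / (num_atoms - 1)

    target_z = torch.tensor([3.7]).clamp(v_min, v_max)
    b = (target_z - v_min) / delta_z              # fractional bin position (2.74)
    low, u = torch.floor(b).long(), torch.ceil(b).long()

    proj = torch.zeros(num_atoms)
    proj[low] += u.float() - b                    # 0.26 of the mass to atom 2
    proj[u] += b - low.float()                    # 0.74 of the mass to atom 3
    print(proj)                                   # unit mass split across two atoms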

View File

@@ -6,7 +6,7 @@ import jax
 import jax.numpy as jnp
 from flax import nnx
 
-from src.jaxrl import utils
+from reppo_alg.jaxrl import utils
 
 
 def torch_he_uniform(

View File

@@ -4,7 +4,7 @@ from torch.distributions import constraints
 from torch.distributions.transforms import Transform
 from torch.distributions.normal import Normal
 
-from src.torchrl.reppo_util import hl_gauss
+from reppo_alg.torchrl.reppo import hl_gauss
 
 
 class TanhTransform(Transform):
@@ -34,7 +34,9 @@ class TanhTransform(Transform):
     codomain = constraints.interval(-1.0, 1.0)
     bijective = True
     sign = +1
-    log2 = torch.log(torch.tensor(2.0)).to("cuda" if torch.cuda.is_available() else "cpu")
+    log2 = torch.log(torch.tensor(2.0)).to(
+        "cuda" if torch.cuda.is_available() else "cpu"
+    )
 
     def __eq__(self, other):
         return isinstance(other, TanhTransform)
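Precomputing log2 as a tensor is typical for tanh transforms, whose log-det-Jacobian is usually evaluated with the numerically stable identity log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x)). Whether this class uses exactly that formula is an assumption, but the identity itself is standard and easy to verify:

    import torch
    import torch.nn.functional as F

    x = torch.randn(8)
    naive = torch.log1p(-torch.tanh(x) ** 2)
    stable = 2.0 * (torch.log(torch.tensor(2.0)) - x - F.softplus(-2.0 * x))
    print(torch.allclose(naive, stable, atol=1e-6))  # True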

View File

@@ -4,7 +4,7 @@ from omegaconf import DictConfig
 def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
     if cfg.env.type == "humanoid_bench":
-        from src.env_utils.torch_wrappers.humanoid_bench_env import (
+        from reppo_alg.env_utils.torch_wrappers.humanoid_bench_env import (
             HumanoidBenchEnv,
         )
@@ -13,7 +13,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
         )
         return envs, envs
     elif cfg.env.type == "isaaclab":
-        from src.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
+        from reppo_alg.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
 
         envs = IsaacLabEnv(
             cfg.env.name,
@@ -24,7 +24,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
         )
         return envs, envs
     elif cfg.env.type == "mjx":
-        from src.env_utils.torch_wrappers.mujoco_playground_env import make_env
+        from reppo_alg.env_utils.torch_wrappers.mujoco_playground_env import make_env
 
         # TODO: Check if re-using same envs for eval could reduce memory usage
         envs, eval_envs = make_env(
@@ -43,7 +43,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
         from mani_skill.utils import gym_utils
         from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper
         from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
-        from src.env_utils.torch_wrappers.maniskill_wrapper import (
+        from reppo_alg.env_utils.torch_wrappers.maniskill_wrapper import (
             ManiSkillWrapper,
         )
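Keeping each wrapper import inside its branch means a backend (IsaacLab, ManiSkill, ...) is only imported when the config actually selects it, so a missing optional dependency doesn't break the others. A hypothetical call, assuming an OmegaConf config with the env.type/env.name fields used above (abbreviated; the real config carries more fields):

    import torch
    from omegaconf import OmegaConf

    from reppo_alg.torchrl.envs import make_envs

    cfg = OmegaConf.create({"env": {"type": "mjx", "name": "CartpoleBalance"}})
    # envs, eval_envs = make_envs(cfg, device=torch.device("cuda"), seed=0)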

View File

@@ -29,7 +29,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from src.torchrl.reppo_util import (
+from reppo_alg.torchrl.reppo import (
     EmpiricalNormalization,
     PerTaskRewardNormalizer,
     RewardNormalizer,
@@ -42,10 +42,6 @@ from torch.amp import GradScaler, autocast
 torch.set_float32_matmul_precision("high")
 
-try:
-    import jax.numpy as jnp
-except ImportError:
-    pass
 
 
 def main():
@@ -90,7 +86,7 @@ def main():
     print(f"Using device: {device}")
 
     if args.env_name.startswith("h1hand-") or args.env_name.startswith("h1-"):
-        from src.env_utils.torch_wrappers.humanoid_bench_env import (
+        from reppo_alg.env_utils.torch_wrappers.humanoid_bench_env import (
             HumanoidBenchEnv,
         )
@@ -98,7 +94,7 @@ def main():
         envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)
         eval_envs = envs
     elif args.env_name.startswith("Isaac-"):
-        from src.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
+        from reppo_alg.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
 
         env_type = "isaaclab"
         envs = IsaacLabEnv(
@@ -110,14 +106,14 @@ def main():
         )
         eval_envs = envs
     elif args.env_name.startswith("MTBench-"):
-        from src.env_utils.torch_wrappers.mtbench_env import MTBenchEnv
+        from reppo_alg.env_utils.torch_wrappers.mtbench_env import MTBenchEnv
 
         env_name = "-".join(args.env_name.split("-")[1:])
         env_type = "mtbench"
         envs = MTBenchEnv(env_name, args.device_rank, args.num_envs, args.seed)
         eval_envs = envs
     else:
-        from src.env_utils.torch_wrappers.mujoco_playground_env import make_env
+        from reppo_alg.env_utils.torch_wrappers.mujoco_playground_env import make_env
 
         # TODO: Check if re-using same envs for eval could reduce memory usage
         env_type = "mujoco_playground"
@@ -133,11 +129,11 @@ def main():
     )
     n_act = envs.num_actions
-    n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0]
+    n_obs = envs.num_obs if isinstance(envs.num_obs, int) else envs.num_obs[0]
     if envs.asymmetric_obs:
         n_critic_obs = (
             envs.num_privileged_obs
-            if type(envs.num_privileged_obs) == int
+            if isinstance(envs.num_privileged_obs, int)
             else envs.num_privileged_obs[0]
        )
     else:
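Beyond satisfying the E721 lint rule ("do not compare types"), isinstance differs from type(x) == int in that it accepts subclasses, which matters for values that merely behave like ints. A quick contrast, using bool as the stock example of an int subclass:

    flag = True                       # bool subclasses int
    print(type(flag) == int)          # False: exact-type comparison
    print(isinstance(flag, int))      # True: subclass-aware check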
@@ -198,7 +194,7 @@ def main():
 
     if args.agent == "fasttd3":
         if env_type in ["mtbench"]:
-            from src.network_utils.fast_td3_nets import (
+            from reppo_alg.network_utils.fast_td3_nets import (
                 MultiTaskActor,
                 MultiTaskCritic,
             )
@@ -206,7 +202,7 @@ def main():
             actor_cls = MultiTaskActor
             critic_cls = MultiTaskCritic
         else:
-            from src.network_utils.fast_td3_nets import Actor, Critic
+            from reppo_alg.network_utils.fast_td3_nets import Actor, Critic
 
             actor_cls = Actor
             critic_cls = Critic
@@ -214,7 +210,7 @@ def main():
         print("Using FastTD3")
     elif args.agent == "fasttd3_simbav2":
         if env_type in ["mtbench"]:
-            from src.network_utils.fast_td3_nets_simbav2 import (
+            from reppo_alg.network_utils.fast_td3_nets_simbav2 import (
                 MultiTaskActor,
                 MultiTaskCritic,
             )
@@ -222,7 +218,7 @@ def main():
             actor_cls = MultiTaskActor
             critic_cls = MultiTaskCritic
         else:
-            from src.network_utils.fast_td3_nets_simbav2 import Actor, Critic
+            from reppo_alg.network_utils.fast_td3_nets_simbav2 import Actor, Critic
 
             actor_cls = Actor
             critic_cls = Critic

View File

@@ -153,7 +153,6 @@ def get_args():
         "h1hand-basketball-v0": H1HandBasketballArgs,
         "h1hand-window-v0": H1HandWindowArgs,
         "h1hand-package-v0": H1HandPackageArgs,
-        "h1hand-truck-v0": H1HandTruckArgs,
         # MuJoCo Playground
         # NOTE: These tasks are not full list of MuJoCo Playground tasks
         "G1JoystickFlatTerrain": G1JoystickFlatTerrainArgs,
@@ -275,13 +274,6 @@ class H1HandPackageArgs(HumanoidBenchArgs):
     v_max: float = 10000.0
 
 
 @dataclass
-class H1HandTruckArgs(HumanoidBenchArgs):
-    env_name: str = "h1hand-truck-v0"
-    v_min: float = -1000.0
-    v_max: float = 1000.0
-
-
-@dataclass
 class MuJoCoPlaygroundArgs(BaseArgs):
     # Default hyperparameters for many of Playground environments
@@ -292,6 +284,7 @@ class MuJoCoPlaygroundArgs(BaseArgs):
     num_eval_envs: int = 1024
     gamma: float = 0.99
 
+
 @dataclass
 class MTBenchArgs(BaseArgs):
     # Default hyperparameters for MTBench
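The args registry above maps env names to small dataclasses whose fields override base defaults, so removing an env (here h1hand-truck-v0) means deleting both its dict entry and its class. A sketch of the pattern (the base-class field defaults below are assumptions for illustration, not the repo's real values):

    from dataclasses import dataclass

    @dataclass
    class HumanoidBenchArgs:
        env_name: str = ""
        v_min: float = -10.0    # assumed default
        v_max: float = 10.0     # assumed default

    @dataclass
    class H1HandPackageArgs(HumanoidBenchArgs):
        env_name: str = "h1hand-package-v0"
        v_max: float = 10000.0  # per-task override, as in the hunk above

    registry = {"h1hand-package-v0": H1HandPackageArgs}
    args = registry["h1hand-package-v0"]()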

View File

@@ -12,7 +12,7 @@ from omegaconf import DictConfig, OmegaConf
 
 import wandb
 
-from src.torchrl.reppo_util import EmpiricalNormalization, hl_gauss
+from reppo_alg.torchrl.reppo import EmpiricalNormalization, hl_gauss
 
 try:
     # Required for avoiding IsaacGym import error
@@ -28,13 +28,8 @@ import torch.optim as optim
 from torchinfo import summary
 from tensordict import TensorDict
 from torch.amp import GradScaler
-from src.torchrl.envs import make_envs
-from src.network_utils.torch_models import Actor, Critic
-
-try:
-    import jax.numpy as jnp
-except ImportError:
-    pass
+from reppo_alg.torchrl.envs import make_envs
+from reppo_alg.network_utils.torch_models import Actor, Critic
 
 torch.set_float32_matmul_precision("medium")
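For reference, torch.set_float32_matmul_precision takes three modes: "highest" keeps full fp32 matmul accumulation, "high" permits TF32 (or comparable) tensor-core math, and "medium" additionally permits bfloat16-based internals, trading a little precision for speed. This script opts for "medium", while the FastTD3 script above uses "high":

    import torch

    # Speed/precision trade-off for float32 matmuls (PyTorch API).
    torch.set_float32_matmul_precision("medium")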
@@ -484,11 +479,11 @@ def main(cfg):
     envs, eval_envs = make_envs(cfg=cfg, device=device, seed=cfg.seed)
 
     n_act = envs.num_actions
-    n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0]
+    n_obs = envs.num_obs if isinstance(envs.num_obs, int) else envs.num_obs[0]
     if envs.asymmetric_obs:
         n_critic_obs = (
             envs.num_privileged_obs
-            if type(envs.num_privileged_obs) == int
+            if isinstance(envs.num_privileged_obs, int)
             else envs.num_privileged_obs[0]
         )
     else: