Fixes build errors due to name conflicts
This commit is contained in:
parent
094ee0c5ba
commit
e2f99648ae
@ -38,11 +38,12 @@ dependencies = [
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
requires = ["uv_build>=0.8.0,<0.9.0"]
|
||||
build-backend = "uv_build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["onpolicy_sac"]
|
||||
[tool.uv.build-backend]
|
||||
module-name = "reppo_alg"
|
||||
module-root = ""
|
||||
|
||||
[tool.ruff]
|
||||
# Exclude a variety of commonly ignored directories.
|
||||
|
0
reppo_alg/__init__.py
Normal file
0
reppo_alg/__init__.py
Normal file
0
reppo_alg/env_utils/__init__.py
Normal file
0
reppo_alg/env_utils/__init__.py
Normal file
@ -3,11 +3,11 @@ from typing import Optional
|
||||
import gymnasium as gym
|
||||
import torch
|
||||
from isaaclab.app import AppLauncher
|
||||
from isaaclab_tasks.utils.parse_cfg import parse_env_cfg
|
||||
|
||||
app_launcher = AppLauncher(headless=True)
|
||||
simulation_app = app_launcher.app
|
||||
|
||||
from isaaclab_tasks.utils.parse_cfg import parse_env_cfg
|
||||
|
||||
|
||||
class IsaacLabEnv:
|
@ -1,6 +1,5 @@
|
||||
import jax
|
||||
from mujoco_playground import registry, wrapper_torch
|
||||
import torch
|
||||
|
||||
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
|
||||
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
|
@ -18,14 +18,14 @@ from jax.random import PRNGKey
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
import wandb
|
||||
from src.env_utils.jax_wrappers import (
|
||||
from reppo_alg.env_utils.jax_wrappers import (
|
||||
BraxGymnaxWrapper,
|
||||
ClipAction,
|
||||
LogWrapper,
|
||||
MjxGymnaxWrapper,
|
||||
)
|
||||
from src.jaxrl import utils
|
||||
from src.jaxrl.normalization import NormalizationState, Normalizer
|
||||
from reppo_alg.jaxrl import utils
|
||||
from reppo_alg.jaxrl.normalization import NormalizationState, Normalizer
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
@ -17,15 +17,15 @@ from jax.random import PRNGKey
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
import wandb
|
||||
from src.env_utils.jax_wrappers import (
|
||||
from reppo_alg.env_utils.jax_wrappers import (
|
||||
BraxGymnaxWrapper,
|
||||
ClipAction,
|
||||
LogWrapper,
|
||||
MjxGymnaxWrapper,
|
||||
NormalizeVec,
|
||||
)
|
||||
from src.jaxrl import utils
|
||||
from src.network_utils.jax_models import (
|
||||
from reppo_alg.jaxrl import utils, muon
|
||||
from reppo_alg.network_utils.jax_models import (
|
||||
CategoricalCriticNetwork,
|
||||
CriticNetwork,
|
||||
SACActorNetworks,
|
||||
@ -33,10 +33,6 @@ from src.network_utils.jax_models import (
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
import mujoco
|
||||
|
||||
print(mujoco.__file__)
|
||||
|
||||
|
||||
class Policy(typing.Protocol):
|
||||
def __call__(
|
||||
@ -242,14 +238,16 @@ def make_init(
|
||||
|
||||
if cfg.max_grad_norm is not None:
|
||||
actor_optimizer = optax.chain(
|
||||
optax.clip_by_global_norm(cfg.max_grad_norm), optax.adam(lr)
|
||||
optax.clip_by_global_norm(cfg.max_grad_norm),
|
||||
muon.muon(lr),  # optax.adam(lr)
|
||||
)
|
||||
critic_optimizer = optax.chain(
|
||||
optax.clip_by_global_norm(cfg.max_grad_norm), optax.adam(lr)
|
||||
optax.clip_by_global_norm(cfg.max_grad_norm),
|
||||
muon.muon(lr),  # optax.adam(lr)
|
||||
)
|
||||
else:
|
||||
actor_optimizer = optax.adam(lr)
|
||||
critic_optimizer = optax.adam(lr)
|
||||
actor_optimizer = muon.muon(lr) # optax.adam(lr)
|
||||
critic_optimizer = muon.muon(lr) # optax.adam(lr)
|
||||
|
||||
actor_trainstate = nnx.TrainState.create(
|
||||
graphdef=nnx.graphdef(actor_networks),
|
||||
@ -282,14 +280,6 @@ def make_init(
|
||||
).astype(jnp.float32)
|
||||
env_state.set_env_state(_env_state)
|
||||
|
||||
# mock_action = jnp.zeros(
|
||||
# (1, 6), dtype=jnp.float32
|
||||
# )
|
||||
# print(mock_action.shape)
|
||||
# print(obs.shape)
|
||||
# print(nnx.tabulate(critic_networks, obs[:1], mock_action))
|
||||
# print(nnx.tabulate(actor_networks, obs[:1]))
|
||||
|
||||
return SACTrainState(
|
||||
actor=actor_trainstate,
|
||||
actor_target=actor_target_trainstate,
|
||||
@ -318,7 +308,6 @@ def make_train_fn(
|
||||
# env = VecEnv(env, cfg.num_envs)
|
||||
if cfg.normalize_env:
|
||||
env = NormalizeVec(env)
|
||||
print(env)
|
||||
eval_fn = make_eval_fn(env, cfg.max_episode_steps, reward_scale=reward_scale)
|
||||
action_size_target = (
|
||||
jnp.prod(jnp.array(env.action_space(env_params).shape)) * cfg.ent_target_mult
|
||||
@ -452,12 +441,10 @@ def make_train_fn(
|
||||
reverse=True,
|
||||
)
|
||||
# Reshape data to (num_steps * num_envs, ...)
|
||||
jax.debug.print("num trunc {}", batch.truncated.sum(), ordered=True)
|
||||
data = (batch, target_values)
|
||||
data = jax.tree.map(
|
||||
lambda x: x.reshape((cfg.num_steps * cfg.num_envs, *x.shape[2:])), data
|
||||
)
|
||||
# jax.debug.print("whole data {}", data[0].truncated.sum(), ordered=True)
|
||||
|
||||
train_state = train_state.replace(
|
||||
actor_target=train_state.actor_target.replace(
|
0
reppo_alg/network_utils/__init__.py
Normal file
0
reppo_alg/network_utils/__init__.py
Normal file
@ -52,13 +52,13 @@ class DistributionalQNetwork(nn.Module):
|
||||
)
|
||||
target_z = target_z.clamp(self.v_min, self.v_max)
|
||||
b = (target_z - self.v_min) / delta_z
|
||||
l = torch.floor(b).long()
|
||||
low = torch.floor(b).long()
|
||||
u = torch.ceil(b).long()
|
||||
|
||||
l_mask = torch.logical_and((u > 0), (l == u))
|
||||
u_mask = torch.logical_and((l < (self.num_atoms - 1)), (l == u))
|
||||
l_mask = torch.logical_and((u > 0), (low == u))
|
||||
u_mask = torch.logical_and((low < (self.num_atoms - 1)), (low == u))
|
||||
|
||||
l = torch.where(l_mask, l - 1, l)
|
||||
low = torch.where(l_mask, low - 1, low)
|
||||
u = torch.where(u_mask, u + 1, u)
|
||||
|
||||
next_dist = F.softmax(self.forward(obs, actions), dim=1)
|
||||
@ -72,10 +72,10 @@ class DistributionalQNetwork(nn.Module):
|
||||
.long()
|
||||
)
|
||||
proj_dist.view(-1).index_add_(
|
||||
0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
|
||||
0, (low + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
|
||||
)
|
||||
proj_dist.view(-1).index_add_(
|
||||
0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1)
|
||||
0, (u + offset).view(-1), (next_dist * (b - low.float())).view(-1)
|
||||
)
|
||||
return proj_dist
|
||||
|
@ -6,7 +6,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from flax import nnx
|
||||
|
||||
from src.jaxrl import utils
|
||||
from reppo_alg.jaxrl import utils
|
||||
|
||||
|
||||
def torch_he_uniform(
|
@ -4,7 +4,7 @@ from torch.distributions import constraints
|
||||
from torch.distributions.transforms import Transform
|
||||
from torch.distributions.normal import Normal
|
||||
|
||||
from src.torchrl.reppo_util import hl_gauss
|
||||
from reppo_alg.torchrl.reppo import hl_gauss
|
||||
|
||||
|
||||
class TanhTransform(Transform):
|
||||
@ -34,7 +34,9 @@ class TanhTransform(Transform):
|
||||
codomain = constraints.interval(-1.0, 1.0)
|
||||
bijective = True
|
||||
sign = +1
|
||||
log2 = torch.log(torch.tensor(2.0)).to("cuda" if torch.cuda.is_available() else "cpu")
|
||||
log2 = torch.log(torch.tensor(2.0)).to(
|
||||
"cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TanhTransform)
|
0
reppo_alg/torchrl/__init__.py
Normal file
0
reppo_alg/torchrl/__init__.py
Normal file
@ -4,7 +4,7 @@ from omegaconf import DictConfig
|
||||
|
||||
def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
|
||||
if cfg.env.type == "humanoid_bench":
|
||||
from src.env_utils.torch_wrappers.humanoid_bench_env import (
|
||||
from reppo_alg.env_utils.torch_wrappers.humanoid_bench_env import (
|
||||
HumanoidBenchEnv,
|
||||
)
|
||||
|
||||
@ -13,7 +13,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
|
||||
)
|
||||
return envs, envs
|
||||
elif cfg.env.type == "isaaclab":
|
||||
from src.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
|
||||
from reppo_alg.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
|
||||
|
||||
envs = IsaacLabEnv(
|
||||
cfg.env.name,
|
||||
@ -24,7 +24,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
|
||||
)
|
||||
return envs, envs
|
||||
elif cfg.env.type == "mjx":
|
||||
from src.env_utils.torch_wrappers.mujoco_playground_env import make_env
|
||||
from reppo_alg.env_utils.torch_wrappers.mujoco_playground_env import make_env
|
||||
|
||||
# TODO: Check whether reusing the same envs for eval could reduce memory usage
|
||||
envs, eval_envs = make_env(
|
||||
@ -43,7 +43,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
|
||||
from mani_skill.utils import gym_utils
|
||||
from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper
|
||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||
from src.env_utils.torch_wrappers.maniskill_wrapper import (
|
||||
from reppo_alg.env_utils.torch_wrappers.maniskill_wrapper import (
|
||||
ManiSkillWrapper,
|
||||
)
|
||||
|
@ -29,7 +29,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from src.torchrl.reppo_util import (
|
||||
from reppo_alg.torchrl.reppo import (
|
||||
EmpiricalNormalization,
|
||||
PerTaskRewardNormalizer,
|
||||
RewardNormalizer,
|
||||
@ -42,10 +42,6 @@ from torch.amp import GradScaler, autocast
|
||||
|
||||
torch.set_float32_matmul_precision("high")
|
||||
|
||||
try:
|
||||
import jax.numpy as jnp
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def main():
|
||||
@ -90,7 +86,7 @@ def main():
|
||||
print(f"Using device: {device}")
|
||||
|
||||
if args.env_name.startswith("h1hand-") or args.env_name.startswith("h1-"):
|
||||
from src.env_utils.torch_wrappers.humanoid_bench_env import (
|
||||
from reppo_alg.env_utils.torch_wrappers.humanoid_bench_env import (
|
||||
HumanoidBenchEnv,
|
||||
)
|
||||
|
||||
@ -98,7 +94,7 @@ def main():
|
||||
envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)
|
||||
eval_envs = envs
|
||||
elif args.env_name.startswith("Isaac-"):
|
||||
from src.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
|
||||
from reppo_alg.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
|
||||
|
||||
env_type = "isaaclab"
|
||||
envs = IsaacLabEnv(
|
||||
@ -110,14 +106,14 @@ def main():
|
||||
)
|
||||
eval_envs = envs
|
||||
elif args.env_name.startswith("MTBench-"):
|
||||
from src.env_utils.torch_wrappers.mtbench_env import MTBenchEnv
|
||||
from reppo_alg.env_utils.torch_wrappers.mtbench_env import MTBenchEnv
|
||||
|
||||
env_name = "-".join(args.env_name.split("-")[1:])
|
||||
env_type = "mtbench"
|
||||
envs = MTBenchEnv(env_name, args.device_rank, args.num_envs, args.seed)
|
||||
eval_envs = envs
|
||||
else:
|
||||
from src.env_utils.torch_wrappers.mujoco_playground_env import make_env
|
||||
from reppo_alg.env_utils.torch_wrappers.mujoco_playground_env import make_env
|
||||
|
||||
# TODO: Check whether reusing the same envs for eval could reduce memory usage
|
||||
env_type = "mujoco_playground"
|
||||
@ -133,11 +129,11 @@ def main():
|
||||
)
|
||||
|
||||
n_act = envs.num_actions
|
||||
n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0]
|
||||
n_obs = envs.num_obs if isinstance(envs.num_obs, int) else envs.num_obs[0]
|
||||
if envs.asymmetric_obs:
|
||||
n_critic_obs = (
|
||||
envs.num_privileged_obs
|
||||
if type(envs.num_privileged_obs) == int
|
||||
if isinstance(envs.num_privileged_obs, int)
|
||||
else envs.num_privileged_obs[0]
|
||||
)
|
||||
else:
|
||||
@ -198,7 +194,7 @@ def main():
|
||||
|
||||
if args.agent == "fasttd3":
|
||||
if env_type in ["mtbench"]:
|
||||
from src.network_utils.fast_td3_nets import (
|
||||
from reppo_alg.network_utils.fast_td3_nets import (
|
||||
MultiTaskActor,
|
||||
MultiTaskCritic,
|
||||
)
|
||||
@ -206,7 +202,7 @@ def main():
|
||||
actor_cls = MultiTaskActor
|
||||
critic_cls = MultiTaskCritic
|
||||
else:
|
||||
from src.network_utils.fast_td3_nets import Actor, Critic
|
||||
from reppo_alg.network_utils.fast_td3_nets import Actor, Critic
|
||||
|
||||
actor_cls = Actor
|
||||
critic_cls = Critic
|
||||
@ -214,7 +210,7 @@ def main():
|
||||
print("Using FastTD3")
|
||||
elif args.agent == "fasttd3_simbav2":
|
||||
if env_type in ["mtbench"]:
|
||||
from src.network_utils.fast_td3_nets_simbav2 import (
|
||||
from reppo_alg.network_utils.fast_td3_nets_simbav2 import (
|
||||
MultiTaskActor,
|
||||
MultiTaskCritic,
|
||||
)
|
||||
@ -222,7 +218,7 @@ def main():
|
||||
actor_cls = MultiTaskActor
|
||||
critic_cls = MultiTaskCritic
|
||||
else:
|
||||
from src.network_utils.fast_td3_nets_simbav2 import Actor, Critic
|
||||
from reppo_alg.network_utils.fast_td3_nets_simbav2 import Actor, Critic
|
||||
|
||||
actor_cls = Actor
|
||||
critic_cls = Critic
|
@ -153,7 +153,6 @@ def get_args():
|
||||
"h1hand-basketball-v0": H1HandBasketballArgs,
|
||||
"h1hand-window-v0": H1HandWindowArgs,
|
||||
"h1hand-package-v0": H1HandPackageArgs,
|
||||
"h1hand-truck-v0": H1HandTruckArgs,
|
||||
# MuJoCo Playground
|
||||
# NOTE: This is not the full list of MuJoCo Playground tasks
|
||||
"G1JoystickFlatTerrain": G1JoystickFlatTerrainArgs,
|
||||
@ -275,13 +274,6 @@ class H1HandPackageArgs(HumanoidBenchArgs):
|
||||
v_max: float = 10000.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class H1HandTruckArgs(HumanoidBenchArgs):
|
||||
env_name: str = "h1hand-truck-v0"
|
||||
v_min: float = -1000.0
|
||||
v_max: float = 1000.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class MuJoCoPlaygroundArgs(BaseArgs):
|
||||
# Default hyperparameters for many of Playground environments
|
||||
@ -292,6 +284,7 @@ class MuJoCoPlaygroundArgs(BaseArgs):
|
||||
num_eval_envs: int = 1024
|
||||
gamma: float = 0.99
|
||||
|
||||
|
||||
@dataclass
|
||||
class MTBenchArgs(BaseArgs):
|
||||
# Default hyperparameters for MTBench
|
@ -12,7 +12,7 @@ from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
import wandb
|
||||
|
||||
from src.torchrl.reppo_util import EmpiricalNormalization, hl_gauss
|
||||
from reppo_alg.torchrl.reppo import EmpiricalNormalization, hl_gauss
|
||||
|
||||
try:
|
||||
# Required for avoiding IsaacGym import error
|
||||
@ -28,13 +28,8 @@ import torch.optim as optim
|
||||
from torchinfo import summary
|
||||
from tensordict import TensorDict
|
||||
from torch.amp import GradScaler
|
||||
from src.torchrl.envs import make_envs
|
||||
from src.network_utils.torch_models import Actor, Critic
|
||||
|
||||
try:
|
||||
import jax.numpy as jnp
|
||||
except ImportError:
|
||||
pass
|
||||
from reppo_alg.torchrl.envs import make_envs
|
||||
from reppo_alg.network_utils.torch_models import Actor, Critic
|
||||
|
||||
|
||||
torch.set_float32_matmul_precision("medium")
|
||||
@ -484,11 +479,11 @@ def main(cfg):
|
||||
envs, eval_envs = make_envs(cfg=cfg, device=device, seed=cfg.seed)
|
||||
|
||||
n_act = envs.num_actions
|
||||
n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0]
|
||||
n_obs = envs.num_obs if isinstance(envs.num_obs, int) else envs.num_obs[0]
|
||||
if envs.asymmetric_obs:
|
||||
n_critic_obs = (
|
||||
envs.num_privileged_obs
|
||||
if type(envs.num_privileged_obs) == int
|
||||
if isinstance(envs.num_privileged_obs, int)
|
||||
else envs.num_privileged_obs[0]
|
||||
)
|
||||
else:
|
0
reppo_alg/torchrl/tensordict_replay_buffer.py
Normal file
0
reppo_alg/torchrl/tensordict_replay_buffer.py
Normal file
Loading…
Reference in New Issue
Block a user