Fixes build errors due to name conflicts
This commit is contained in:
parent
094ee0c5ba
commit
e2f99648ae
@ -38,11 +38,12 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["hatchling"]
|
requires = ["uv_build>=0.8.0,<0.9.0"]
|
||||||
build-backend = "hatchling.build"
|
build-backend = "uv_build"
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
[tool.uv.build-backend]
|
||||||
packages = ["onpolicy_sac"]
|
module-name = "reppo_alg"
|
||||||
|
module-root = ""
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
# Exclude a variety of commonly ignored directories.
|
# Exclude a variety of commonly ignored directories.
|
||||||
|
0
reppo_alg/__init__.py
Normal file
0
reppo_alg/__init__.py
Normal file
0
reppo_alg/env_utils/__init__.py
Normal file
0
reppo_alg/env_utils/__init__.py
Normal file
@ -3,11 +3,11 @@ from typing import Optional
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import torch
|
import torch
|
||||||
from isaaclab.app import AppLauncher
|
from isaaclab.app import AppLauncher
|
||||||
|
from isaaclab_tasks.utils.parse_cfg import parse_env_cfg
|
||||||
|
|
||||||
app_launcher = AppLauncher(headless=True)
|
app_launcher = AppLauncher(headless=True)
|
||||||
simulation_app = app_launcher.app
|
simulation_app = app_launcher.app
|
||||||
|
|
||||||
from isaaclab_tasks.utils.parse_cfg import parse_env_cfg
|
|
||||||
|
|
||||||
|
|
||||||
class IsaacLabEnv:
|
class IsaacLabEnv:
|
@ -1,6 +1,5 @@
|
|||||||
import jax
|
import jax
|
||||||
from mujoco_playground import registry, wrapper_torch
|
from mujoco_playground import registry, wrapper_torch
|
||||||
import torch
|
|
||||||
|
|
||||||
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
|
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
|
||||||
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
|
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
|
||||||
@ -70,7 +69,7 @@ class RandomizeInitialWrapper(wrapper_torch.RSLRLBraxWrapper):
|
|||||||
self.key, self.env_state.info["steps"].shape, 0, 1000
|
self.key, self.env_state.info["steps"].shape, 0, 1000
|
||||||
).astype(jax.numpy.float32)
|
).astype(jax.numpy.float32)
|
||||||
return obs, critic_obs
|
return obs, critic_obs
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
obs, reward, done, info = super().step(action)
|
obs, reward, done, info = super().step(action)
|
||||||
return obs, reward, done, done, info
|
return obs, reward, done, done, info
|
@ -18,14 +18,14 @@ from jax.random import PRNGKey
|
|||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
from src.env_utils.jax_wrappers import (
|
from reppo_alg.env_utils.jax_wrappers import (
|
||||||
BraxGymnaxWrapper,
|
BraxGymnaxWrapper,
|
||||||
ClipAction,
|
ClipAction,
|
||||||
LogWrapper,
|
LogWrapper,
|
||||||
MjxGymnaxWrapper,
|
MjxGymnaxWrapper,
|
||||||
)
|
)
|
||||||
from src.jaxrl import utils
|
from reppo_alg.jaxrl import utils
|
||||||
from src.jaxrl.normalization import NormalizationState, Normalizer
|
from reppo_alg.jaxrl.normalization import NormalizationState, Normalizer
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
@ -17,15 +17,15 @@ from jax.random import PRNGKey
|
|||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
from src.env_utils.jax_wrappers import (
|
from reppo_alg.env_utils.jax_wrappers import (
|
||||||
BraxGymnaxWrapper,
|
BraxGymnaxWrapper,
|
||||||
ClipAction,
|
ClipAction,
|
||||||
LogWrapper,
|
LogWrapper,
|
||||||
MjxGymnaxWrapper,
|
MjxGymnaxWrapper,
|
||||||
NormalizeVec,
|
NormalizeVec,
|
||||||
)
|
)
|
||||||
from src.jaxrl import utils
|
from reppo_alg.jaxrl import utils, muon
|
||||||
from src.network_utils.jax_models import (
|
from reppo_alg.network_utils.jax_models import (
|
||||||
CategoricalCriticNetwork,
|
CategoricalCriticNetwork,
|
||||||
CriticNetwork,
|
CriticNetwork,
|
||||||
SACActorNetworks,
|
SACActorNetworks,
|
||||||
@ -33,10 +33,6 @@ from src.network_utils.jax_models import (
|
|||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
import mujoco
|
|
||||||
|
|
||||||
print(mujoco.__file__)
|
|
||||||
|
|
||||||
|
|
||||||
class Policy(typing.Protocol):
|
class Policy(typing.Protocol):
|
||||||
def __call__(
|
def __call__(
|
||||||
@ -242,14 +238,16 @@ def make_init(
|
|||||||
|
|
||||||
if cfg.max_grad_norm is not None:
|
if cfg.max_grad_norm is not None:
|
||||||
actor_optimizer = optax.chain(
|
actor_optimizer = optax.chain(
|
||||||
optax.clip_by_global_norm(cfg.max_grad_norm), optax.adam(lr)
|
optax.clip_by_global_norm(cfg.max_grad_norm),
|
||||||
|
muon.muon(lr), # optax.adam(lr) optax.adam(lr)
|
||||||
)
|
)
|
||||||
critic_optimizer = optax.chain(
|
critic_optimizer = optax.chain(
|
||||||
optax.clip_by_global_norm(cfg.max_grad_norm), optax.adam(lr)
|
optax.clip_by_global_norm(cfg.max_grad_norm),
|
||||||
|
muon.muon(lr), # optax.adam(lr) optax.adam(lr)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
actor_optimizer = optax.adam(lr)
|
actor_optimizer = muon.muon(lr) # optax.adam(lr)
|
||||||
critic_optimizer = optax.adam(lr)
|
critic_optimizer = muon.muon(lr) # optax.adam(lr)
|
||||||
|
|
||||||
actor_trainstate = nnx.TrainState.create(
|
actor_trainstate = nnx.TrainState.create(
|
||||||
graphdef=nnx.graphdef(actor_networks),
|
graphdef=nnx.graphdef(actor_networks),
|
||||||
@ -282,14 +280,6 @@ def make_init(
|
|||||||
).astype(jnp.float32)
|
).astype(jnp.float32)
|
||||||
env_state.set_env_state(_env_state)
|
env_state.set_env_state(_env_state)
|
||||||
|
|
||||||
# mock_action = jnp.zeros(
|
|
||||||
# (1, 6), dtype=jnp.float32
|
|
||||||
# )
|
|
||||||
# print(mock_action.shape)
|
|
||||||
# print(obs.shape)
|
|
||||||
# print(nnx.tabulate(critic_networks, obs[:1], mock_action))
|
|
||||||
# print(nnx.tabulate(actor_networks, obs[:1]))
|
|
||||||
|
|
||||||
return SACTrainState(
|
return SACTrainState(
|
||||||
actor=actor_trainstate,
|
actor=actor_trainstate,
|
||||||
actor_target=actor_target_trainstate,
|
actor_target=actor_target_trainstate,
|
||||||
@ -318,7 +308,6 @@ def make_train_fn(
|
|||||||
# env = VecEnv(env, cfg.num_envs)
|
# env = VecEnv(env, cfg.num_envs)
|
||||||
if cfg.normalize_env:
|
if cfg.normalize_env:
|
||||||
env = NormalizeVec(env)
|
env = NormalizeVec(env)
|
||||||
print(env)
|
|
||||||
eval_fn = make_eval_fn(env, cfg.max_episode_steps, reward_scale=reward_scale)
|
eval_fn = make_eval_fn(env, cfg.max_episode_steps, reward_scale=reward_scale)
|
||||||
action_size_target = (
|
action_size_target = (
|
||||||
jnp.prod(jnp.array(env.action_space(env_params).shape)) * cfg.ent_target_mult
|
jnp.prod(jnp.array(env.action_space(env_params).shape)) * cfg.ent_target_mult
|
||||||
@ -452,12 +441,10 @@ def make_train_fn(
|
|||||||
reverse=True,
|
reverse=True,
|
||||||
)
|
)
|
||||||
# Reshape data to (num_steps * num_envs, ...)
|
# Reshape data to (num_steps * num_envs, ...)
|
||||||
jax.debug.print("num trunc {}", batch.truncated.sum(), ordered=True)
|
|
||||||
data = (batch, target_values)
|
data = (batch, target_values)
|
||||||
data = jax.tree.map(
|
data = jax.tree.map(
|
||||||
lambda x: x.reshape((cfg.num_steps * cfg.num_envs, *x.shape[2:])), data
|
lambda x: x.reshape((cfg.num_steps * cfg.num_envs, *x.shape[2:])), data
|
||||||
)
|
)
|
||||||
# jax.debug.print("whole data {}", data[0].truncated.sum(), ordered=True)
|
|
||||||
|
|
||||||
train_state = train_state.replace(
|
train_state = train_state.replace(
|
||||||
actor_target=train_state.actor_target.replace(
|
actor_target=train_state.actor_target.replace(
|
0
reppo_alg/network_utils/__init__.py
Normal file
0
reppo_alg/network_utils/__init__.py
Normal file
@ -52,13 +52,13 @@ class DistributionalQNetwork(nn.Module):
|
|||||||
)
|
)
|
||||||
target_z = target_z.clamp(self.v_min, self.v_max)
|
target_z = target_z.clamp(self.v_min, self.v_max)
|
||||||
b = (target_z - self.v_min) / delta_z
|
b = (target_z - self.v_min) / delta_z
|
||||||
l = torch.floor(b).long()
|
low = torch.floor(b).long()
|
||||||
u = torch.ceil(b).long()
|
u = torch.ceil(b).long()
|
||||||
|
|
||||||
l_mask = torch.logical_and((u > 0), (l == u))
|
l_mask = torch.logical_and((u > 0), (low == u))
|
||||||
u_mask = torch.logical_and((l < (self.num_atoms - 1)), (l == u))
|
u_mask = torch.logical_and((low < (self.num_atoms - 1)), (low == u))
|
||||||
|
|
||||||
l = torch.where(l_mask, l - 1, l)
|
low = torch.where(l_mask, low - 1, low)
|
||||||
u = torch.where(u_mask, u + 1, u)
|
u = torch.where(u_mask, u + 1, u)
|
||||||
|
|
||||||
next_dist = F.softmax(self.forward(obs, actions), dim=1)
|
next_dist = F.softmax(self.forward(obs, actions), dim=1)
|
||||||
@ -72,10 +72,10 @@ class DistributionalQNetwork(nn.Module):
|
|||||||
.long()
|
.long()
|
||||||
)
|
)
|
||||||
proj_dist.view(-1).index_add_(
|
proj_dist.view(-1).index_add_(
|
||||||
0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
|
0, (low + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
|
||||||
)
|
)
|
||||||
proj_dist.view(-1).index_add_(
|
proj_dist.view(-1).index_add_(
|
||||||
0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1)
|
0, (u + offset).view(-1), (next_dist * (b - low.float())).view(-1)
|
||||||
)
|
)
|
||||||
return proj_dist
|
return proj_dist
|
||||||
|
|
@ -6,7 +6,7 @@ import jax
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax import nnx
|
from flax import nnx
|
||||||
|
|
||||||
from src.jaxrl import utils
|
from reppo_alg.jaxrl import utils
|
||||||
|
|
||||||
|
|
||||||
def torch_he_uniform(
|
def torch_he_uniform(
|
@ -4,7 +4,7 @@ from torch.distributions import constraints
|
|||||||
from torch.distributions.transforms import Transform
|
from torch.distributions.transforms import Transform
|
||||||
from torch.distributions.normal import Normal
|
from torch.distributions.normal import Normal
|
||||||
|
|
||||||
from src.torchrl.reppo_util import hl_gauss
|
from reppo_alg.torchrl.reppo import hl_gauss
|
||||||
|
|
||||||
|
|
||||||
class TanhTransform(Transform):
|
class TanhTransform(Transform):
|
||||||
@ -34,7 +34,9 @@ class TanhTransform(Transform):
|
|||||||
codomain = constraints.interval(-1.0, 1.0)
|
codomain = constraints.interval(-1.0, 1.0)
|
||||||
bijective = True
|
bijective = True
|
||||||
sign = +1
|
sign = +1
|
||||||
log2 = torch.log(torch.tensor(2.0)).to("cuda" if torch.cuda.is_available() else "cpu")
|
log2 = torch.log(torch.tensor(2.0)).to(
|
||||||
|
"cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return isinstance(other, TanhTransform)
|
return isinstance(other, TanhTransform)
|
0
reppo_alg/torchrl/__init__.py
Normal file
0
reppo_alg/torchrl/__init__.py
Normal file
@ -4,7 +4,7 @@ from omegaconf import DictConfig
|
|||||||
|
|
||||||
def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
|
def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
|
||||||
if cfg.env.type == "humanoid_bench":
|
if cfg.env.type == "humanoid_bench":
|
||||||
from src.env_utils.torch_wrappers.humanoid_bench_env import (
|
from reppo_alg.env_utils.torch_wrappers.humanoid_bench_env import (
|
||||||
HumanoidBenchEnv,
|
HumanoidBenchEnv,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -13,7 +13,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
|
|||||||
)
|
)
|
||||||
return envs, envs
|
return envs, envs
|
||||||
elif cfg.env.type == "isaaclab":
|
elif cfg.env.type == "isaaclab":
|
||||||
from src.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
|
from reppo_alg.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
|
||||||
|
|
||||||
envs = IsaacLabEnv(
|
envs = IsaacLabEnv(
|
||||||
cfg.env.name,
|
cfg.env.name,
|
||||||
@ -24,7 +24,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
|
|||||||
)
|
)
|
||||||
return envs, envs
|
return envs, envs
|
||||||
elif cfg.env.type == "mjx":
|
elif cfg.env.type == "mjx":
|
||||||
from src.env_utils.torch_wrappers.mujoco_playground_env import make_env
|
from reppo_alg.env_utils.torch_wrappers.mujoco_playground_env import make_env
|
||||||
|
|
||||||
# TODO: Check if re-using same envs for eval could reduce memory usage
|
# TODO: Check if re-using same envs for eval could reduce memory usage
|
||||||
envs, eval_envs = make_env(
|
envs, eval_envs = make_env(
|
||||||
@ -43,7 +43,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
|
|||||||
from mani_skill.utils import gym_utils
|
from mani_skill.utils import gym_utils
|
||||||
from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper
|
from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper
|
||||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||||
from src.env_utils.torch_wrappers.maniskill_wrapper import (
|
from reppo_alg.env_utils.torch_wrappers.maniskill_wrapper import (
|
||||||
ManiSkillWrapper,
|
ManiSkillWrapper,
|
||||||
)
|
)
|
||||||
|
|
@ -29,7 +29,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from src.torchrl.reppo_util import (
|
from reppo_alg.torchrl.reppo import (
|
||||||
EmpiricalNormalization,
|
EmpiricalNormalization,
|
||||||
PerTaskRewardNormalizer,
|
PerTaskRewardNormalizer,
|
||||||
RewardNormalizer,
|
RewardNormalizer,
|
||||||
@ -42,10 +42,6 @@ from torch.amp import GradScaler, autocast
|
|||||||
|
|
||||||
torch.set_float32_matmul_precision("high")
|
torch.set_float32_matmul_precision("high")
|
||||||
|
|
||||||
try:
|
|
||||||
import jax.numpy as jnp
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@ -90,7 +86,7 @@ def main():
|
|||||||
print(f"Using device: {device}")
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
if args.env_name.startswith("h1hand-") or args.env_name.startswith("h1-"):
|
if args.env_name.startswith("h1hand-") or args.env_name.startswith("h1-"):
|
||||||
from src.env_utils.torch_wrappers.humanoid_bench_env import (
|
from reppo_alg.env_utils.torch_wrappers.humanoid_bench_env import (
|
||||||
HumanoidBenchEnv,
|
HumanoidBenchEnv,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -98,7 +94,7 @@ def main():
|
|||||||
envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)
|
envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)
|
||||||
eval_envs = envs
|
eval_envs = envs
|
||||||
elif args.env_name.startswith("Isaac-"):
|
elif args.env_name.startswith("Isaac-"):
|
||||||
from src.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
|
from reppo_alg.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
|
||||||
|
|
||||||
env_type = "isaaclab"
|
env_type = "isaaclab"
|
||||||
envs = IsaacLabEnv(
|
envs = IsaacLabEnv(
|
||||||
@ -110,14 +106,14 @@ def main():
|
|||||||
)
|
)
|
||||||
eval_envs = envs
|
eval_envs = envs
|
||||||
elif args.env_name.startswith("MTBench-"):
|
elif args.env_name.startswith("MTBench-"):
|
||||||
from src.env_utils.torch_wrappers.mtbench_env import MTBenchEnv
|
from reppo_alg.env_utils.torch_wrappers.mtbench_env import MTBenchEnv
|
||||||
|
|
||||||
env_name = "-".join(args.env_name.split("-")[1:])
|
env_name = "-".join(args.env_name.split("-")[1:])
|
||||||
env_type = "mtbench"
|
env_type = "mtbench"
|
||||||
envs = MTBenchEnv(env_name, args.device_rank, args.num_envs, args.seed)
|
envs = MTBenchEnv(env_name, args.device_rank, args.num_envs, args.seed)
|
||||||
eval_envs = envs
|
eval_envs = envs
|
||||||
else:
|
else:
|
||||||
from src.env_utils.torch_wrappers.mujoco_playground_env import make_env
|
from reppo_alg.env_utils.torch_wrappers.mujoco_playground_env import make_env
|
||||||
|
|
||||||
# TODO: Check if re-using same envs for eval could reduce memory usage
|
# TODO: Check if re-using same envs for eval could reduce memory usage
|
||||||
env_type = "mujoco_playground"
|
env_type = "mujoco_playground"
|
||||||
@ -133,11 +129,11 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
n_act = envs.num_actions
|
n_act = envs.num_actions
|
||||||
n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0]
|
n_obs = envs.num_obs if isinstance(envs.num_obs, int) else envs.num_obs[0]
|
||||||
if envs.asymmetric_obs:
|
if envs.asymmetric_obs:
|
||||||
n_critic_obs = (
|
n_critic_obs = (
|
||||||
envs.num_privileged_obs
|
envs.num_privileged_obs
|
||||||
if type(envs.num_privileged_obs) == int
|
if isinstance(envs.num_privileged_obs, int)
|
||||||
else envs.num_privileged_obs[0]
|
else envs.num_privileged_obs[0]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -198,7 +194,7 @@ def main():
|
|||||||
|
|
||||||
if args.agent == "fasttd3":
|
if args.agent == "fasttd3":
|
||||||
if env_type in ["mtbench"]:
|
if env_type in ["mtbench"]:
|
||||||
from src.network_utils.fast_td3_nets import (
|
from reppo_alg.network_utils.fast_td3_nets import (
|
||||||
MultiTaskActor,
|
MultiTaskActor,
|
||||||
MultiTaskCritic,
|
MultiTaskCritic,
|
||||||
)
|
)
|
||||||
@ -206,7 +202,7 @@ def main():
|
|||||||
actor_cls = MultiTaskActor
|
actor_cls = MultiTaskActor
|
||||||
critic_cls = MultiTaskCritic
|
critic_cls = MultiTaskCritic
|
||||||
else:
|
else:
|
||||||
from src.network_utils.fast_td3_nets import Actor, Critic
|
from reppo_alg.network_utils.fast_td3_nets import Actor, Critic
|
||||||
|
|
||||||
actor_cls = Actor
|
actor_cls = Actor
|
||||||
critic_cls = Critic
|
critic_cls = Critic
|
||||||
@ -214,7 +210,7 @@ def main():
|
|||||||
print("Using FastTD3")
|
print("Using FastTD3")
|
||||||
elif args.agent == "fasttd3_simbav2":
|
elif args.agent == "fasttd3_simbav2":
|
||||||
if env_type in ["mtbench"]:
|
if env_type in ["mtbench"]:
|
||||||
from src.network_utils.fast_td3_nets_simbav2 import (
|
from reppo_alg.network_utils.fast_td3_nets_simbav2 import (
|
||||||
MultiTaskActor,
|
MultiTaskActor,
|
||||||
MultiTaskCritic,
|
MultiTaskCritic,
|
||||||
)
|
)
|
||||||
@ -222,7 +218,7 @@ def main():
|
|||||||
actor_cls = MultiTaskActor
|
actor_cls = MultiTaskActor
|
||||||
critic_cls = MultiTaskCritic
|
critic_cls = MultiTaskCritic
|
||||||
else:
|
else:
|
||||||
from src.network_utils.fast_td3_nets_simbav2 import Actor, Critic
|
from reppo_alg.network_utils.fast_td3_nets_simbav2 import Actor, Critic
|
||||||
|
|
||||||
actor_cls = Actor
|
actor_cls = Actor
|
||||||
critic_cls = Critic
|
critic_cls = Critic
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
|||||||
import tyro
|
import tyro
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseArgs:
|
class BaseArgs:
|
||||||
# Default hyperparameters -- specifically for HumanoidBench
|
# Default hyperparameters -- specifically for HumanoidBench
|
||||||
# See MuJoCoPlaygroundArgs for default hyperparameters for MuJoCo Playground
|
# See MuJoCoPlaygroundArgs for default hyperparameters for MuJoCo Playground
|
||||||
@ -153,7 +153,6 @@ def get_args():
|
|||||||
"h1hand-basketball-v0": H1HandBasketballArgs,
|
"h1hand-basketball-v0": H1HandBasketballArgs,
|
||||||
"h1hand-window-v0": H1HandWindowArgs,
|
"h1hand-window-v0": H1HandWindowArgs,
|
||||||
"h1hand-package-v0": H1HandPackageArgs,
|
"h1hand-package-v0": H1HandPackageArgs,
|
||||||
"h1hand-truck-v0": H1HandTruckArgs,
|
|
||||||
# MuJoCo Playground
|
# MuJoCo Playground
|
||||||
# NOTE: These tasks are not full list of MuJoCo Playground tasks
|
# NOTE: These tasks are not full list of MuJoCo Playground tasks
|
||||||
"G1JoystickFlatTerrain": G1JoystickFlatTerrainArgs,
|
"G1JoystickFlatTerrain": G1JoystickFlatTerrainArgs,
|
||||||
@ -275,13 +274,6 @@ class H1HandPackageArgs(HumanoidBenchArgs):
|
|||||||
v_max: float = 10000.0
|
v_max: float = 10000.0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class H1HandTruckArgs(HumanoidBenchArgs):
|
|
||||||
env_name: str = "h1hand-truck-v0"
|
|
||||||
v_min: float = -1000.0
|
|
||||||
v_max: float = 1000.0
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MuJoCoPlaygroundArgs(BaseArgs):
|
class MuJoCoPlaygroundArgs(BaseArgs):
|
||||||
# Default hyperparameters for many of Playground environments
|
# Default hyperparameters for many of Playground environments
|
||||||
@ -292,6 +284,7 @@ class MuJoCoPlaygroundArgs(BaseArgs):
|
|||||||
num_eval_envs: int = 1024
|
num_eval_envs: int = 1024
|
||||||
gamma: float = 0.99
|
gamma: float = 0.99
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MTBenchArgs(BaseArgs):
|
class MTBenchArgs(BaseArgs):
|
||||||
# Default hyperparameters for MTBench
|
# Default hyperparameters for MTBench
|
@ -12,7 +12,7 @@ from omegaconf import DictConfig, OmegaConf
|
|||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
from src.torchrl.reppo_util import EmpiricalNormalization, hl_gauss
|
from reppo_alg.torchrl.reppo import EmpiricalNormalization, hl_gauss
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Required for avoiding IsaacGym import error
|
# Required for avoiding IsaacGym import error
|
||||||
@ -28,13 +28,8 @@ import torch.optim as optim
|
|||||||
from torchinfo import summary
|
from torchinfo import summary
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
from torch.amp import GradScaler
|
from torch.amp import GradScaler
|
||||||
from src.torchrl.envs import make_envs
|
from reppo_alg.torchrl.envs import make_envs
|
||||||
from src.network_utils.torch_models import Actor, Critic
|
from reppo_alg.network_utils.torch_models import Actor, Critic
|
||||||
|
|
||||||
try:
|
|
||||||
import jax.numpy as jnp
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
torch.set_float32_matmul_precision("medium")
|
torch.set_float32_matmul_precision("medium")
|
||||||
@ -484,11 +479,11 @@ def main(cfg):
|
|||||||
envs, eval_envs = make_envs(cfg=cfg, device=device, seed=cfg.seed)
|
envs, eval_envs = make_envs(cfg=cfg, device=device, seed=cfg.seed)
|
||||||
|
|
||||||
n_act = envs.num_actions
|
n_act = envs.num_actions
|
||||||
n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0]
|
n_obs = envs.num_obs if isinstance(envs.num_obs, int) else envs.num_obs[0]
|
||||||
if envs.asymmetric_obs:
|
if envs.asymmetric_obs:
|
||||||
n_critic_obs = (
|
n_critic_obs = (
|
||||||
envs.num_privileged_obs
|
envs.num_privileged_obs
|
||||||
if type(envs.num_privileged_obs) == int
|
if isinstance(envs.num_privileged_obs, int)
|
||||||
else envs.num_privileged_obs[0]
|
else envs.num_privileged_obs[0]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
0
reppo_alg/torchrl/tensordict_replay_buffer.py
Normal file
0
reppo_alg/torchrl/tensordict_replay_buffer.py
Normal file
Loading…
Reference in New Issue
Block a user