fix torch implementation

This commit is contained in:
Axel Brunnbauer 2025-07-15 22:35:58 -07:00
parent 86fd47b04e
commit 011cbce7f8
3 changed files with 8 additions and 11 deletions

View File

@ -4,7 +4,7 @@ from torch.distributions import constraints
from torch.distributions.transforms import Transform
from torch.distributions.normal import Normal
from src.torchrl.reppo import hl_gauss
from src.torchrl.reppo_util import hl_gauss
class TanhTransform(Transform):

View File

@ -4,7 +4,7 @@ from omegaconf import DictConfig
def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
if cfg.env.type == "humanoid_bench":
from reppo.env_utils.torch_wrappers.humanoid_bench_env import (
from src.env_utils.torch_wrappers.humanoid_bench_env import (
HumanoidBenchEnv,
)
@ -13,7 +13,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
)
return envs, envs
elif cfg.env.type == "isaaclab":
from reppo.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
from src.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
envs = IsaacLabEnv(
cfg.env.name,
@ -24,7 +24,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
)
return envs, envs
elif cfg.env.type == "mjx":
from reppo.env_utils.torch_wrappers.mujoco_playground_env import make_env
from src.env_utils.torch_wrappers.mujoco_playground_env import make_env
# TODO: Check if re-using same envs for eval could reduce memory usage
envs, eval_envs = make_env(
@ -43,7 +43,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
from mani_skill.utils import gym_utils
from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from reppo.env_utils.torch_wrappers.maniskill_wrapper import (
from src.env_utils.torch_wrappers.maniskill_wrapper import (
ManiSkillWrapper,
)

View File

@ -12,6 +12,8 @@ from omegaconf import DictConfig, OmegaConf
import wandb
from src.torchrl.reppo_util import EmpiricalNormalization, hl_gauss
try:
# Required for avoiding IsaacGym import error
import isaacgym
@ -28,10 +30,6 @@ from tensordict import TensorDict
from torch.amp import GradScaler
from src.torchrl.envs import make_envs
from src.network_utils.torch_models import Actor, Critic
from src.torchrl.reppo import (
EmpiricalNormalization,
hl_gauss,
)
try:
import jax.numpy as jnp
@ -445,10 +443,9 @@ def configure_platform(cfg: DictConfig) -> DictConfig:
@hydra.main(
version_base=None,
config_path="../../config",
config_name="sac",
config_name="reppo",
)
def main(cfg):
cfg.hyperparameters = OmegaConf.merge(cfg.hyperparameters, cfg.experiment_overrides)
cfg = configure_platform(cfg)
run_name = f"{cfg.env.name}_torch_{cfg.seed}"