From 011cbce7f8abdacf8323881997ca017a401a826e Mon Sep 17 00:00:00 2001 From: Axel Brunnbauer Date: Tue, 15 Jul 2025 22:35:58 -0700 Subject: [PATCH] fix torch implementation --- src/network_utils/torch_models.py | 2 +- src/torchrl/envs.py | 8 ++++---- src/torchrl/reppo.py | 9 +++------ 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/network_utils/torch_models.py b/src/network_utils/torch_models.py index 7ebd8e5..8457812 100644 --- a/src/network_utils/torch_models.py +++ b/src/network_utils/torch_models.py @@ -4,7 +4,7 @@ from torch.distributions import constraints from torch.distributions.transforms import Transform from torch.distributions.normal import Normal -from src.torchrl.reppo import hl_gauss +from src.torchrl.reppo_util import hl_gauss class TanhTransform(Transform): diff --git a/src/torchrl/envs.py b/src/torchrl/envs.py index 15e71ac..14a0592 100644 --- a/src/torchrl/envs.py +++ b/src/torchrl/envs.py @@ -4,7 +4,7 @@ from omegaconf import DictConfig def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple: if cfg.env.type == "humanoid_bench": - from reppo.env_utils.torch_wrappers.humanoid_bench_env import ( + from src.env_utils.torch_wrappers.humanoid_bench_env import ( HumanoidBenchEnv, ) @@ -13,7 +13,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple: ) return envs, envs elif cfg.env.type == "isaaclab": - from reppo.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv + from src.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv envs = IsaacLabEnv( cfg.env.name, @@ -24,7 +24,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple: ) return envs, envs elif cfg.env.type == "mjx": - from reppo.env_utils.torch_wrappers.mujoco_playground_env import make_env + from src.env_utils.torch_wrappers.mujoco_playground_env import make_env # TODO: Check if re-using same envs for eval could reduce memory usage envs, eval_envs = make_env( @@ -43,7 +43,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple: from mani_skill.utils import gym_utils from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv - from reppo.env_utils.torch_wrappers.maniskill_wrapper import ( + from src.env_utils.torch_wrappers.maniskill_wrapper import ( ManiSkillWrapper, ) diff --git a/src/torchrl/reppo.py b/src/torchrl/reppo.py index d0f269d..ddeae49 100644 --- a/src/torchrl/reppo.py +++ b/src/torchrl/reppo.py @@ -12,6 +12,8 @@ from omegaconf import DictConfig, OmegaConf import wandb +from src.torchrl.reppo_util import EmpiricalNormalization, hl_gauss + try: # Required for avoiding IsaacGym import error import isaacgym @@ -28,10 +30,6 @@ from tensordict import TensorDict from torch.amp import GradScaler from src.torchrl.envs import make_envs from src.network_utils.torch_models import Actor, Critic -from src.torchrl.reppo import ( - EmpiricalNormalization, - hl_gauss, -) try: import jax.numpy as jnp @@ -445,10 +443,9 @@ def configure_platform(cfg: DictConfig) -> DictConfig: @hydra.main( version_base=None, config_path="../../config", - config_name="sac", + config_name="reppo", ) def main(cfg): - cfg.hyperparameters = OmegaConf.merge(cfg.hyperparameters, cfg.experiment_overrides) cfg = configure_platform(cfg) run_name = f"{cfg.env.name}_torch_{cfg.seed}"