fix torch implementation
This commit is contained in:
parent
86fd47b04e
commit
011cbce7f8
@ -4,7 +4,7 @@ from torch.distributions import constraints
|
||||
from torch.distributions.transforms import Transform
|
||||
from torch.distributions.normal import Normal
|
||||
|
||||
from src.torchrl.reppo import hl_gauss
|
||||
from src.torchrl.reppo_util import hl_gauss
|
||||
|
||||
|
||||
class TanhTransform(Transform):
|
||||
|
@ -4,7 +4,7 @@ from omegaconf import DictConfig
|
||||
|
||||
def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
|
||||
if cfg.env.type == "humanoid_bench":
|
||||
from reppo.env_utils.torch_wrappers.humanoid_bench_env import (
|
||||
from src.env_utils.torch_wrappers.humanoid_bench_env import (
|
||||
HumanoidBenchEnv,
|
||||
)
|
||||
|
||||
@ -13,7 +13,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
|
||||
)
|
||||
return envs, envs
|
||||
elif cfg.env.type == "isaaclab":
|
||||
from reppo.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
|
||||
from src.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv
|
||||
|
||||
envs = IsaacLabEnv(
|
||||
cfg.env.name,
|
||||
@ -24,7 +24,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
|
||||
)
|
||||
return envs, envs
|
||||
elif cfg.env.type == "mjx":
|
||||
from reppo.env_utils.torch_wrappers.mujoco_playground_env import make_env
|
||||
from src.env_utils.torch_wrappers.mujoco_playground_env import make_env
|
||||
|
||||
# TODO: Check if re-using same envs for eval could reduce memory usage
|
||||
envs, eval_envs = make_env(
|
||||
@ -43,7 +43,7 @@ def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
|
||||
from mani_skill.utils import gym_utils
|
||||
from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper
|
||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||
from reppo.env_utils.torch_wrappers.maniskill_wrapper import (
|
||||
from src.env_utils.torch_wrappers.maniskill_wrapper import (
|
||||
ManiSkillWrapper,
|
||||
)
|
||||
|
||||
|
@ -12,6 +12,8 @@ from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
import wandb
|
||||
|
||||
from src.torchrl.reppo_util import EmpiricalNormalization, hl_gauss
|
||||
|
||||
try:
|
||||
# Required for avoiding IsaacGym import error
|
||||
import isaacgym
|
||||
@ -28,10 +30,6 @@ from tensordict import TensorDict
|
||||
from torch.amp import GradScaler
|
||||
from src.torchrl.envs import make_envs
|
||||
from src.network_utils.torch_models import Actor, Critic
|
||||
from src.torchrl.reppo import (
|
||||
EmpiricalNormalization,
|
||||
hl_gauss,
|
||||
)
|
||||
|
||||
try:
|
||||
import jax.numpy as jnp
|
||||
@ -445,10 +443,9 @@ def configure_platform(cfg: DictConfig) -> DictConfig:
|
||||
@hydra.main(
|
||||
version_base=None,
|
||||
config_path="../../config",
|
||||
config_name="sac",
|
||||
config_name="reppo",
|
||||
)
|
||||
def main(cfg):
|
||||
cfg.hyperparameters = OmegaConf.merge(cfg.hyperparameters, cfg.experiment_overrides)
|
||||
cfg = configure_platform(cfg)
|
||||
run_name = f"{cfg.env.name}_torch_{cfg.seed}"
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user