reppo/reppo_alg/torchrl/envs.py
2025-07-21 18:31:20 -04:00

96 lines
3.4 KiB
Python

import torch
from omegaconf import DictConfig
def make_envs(cfg: DictConfig, device: torch.device, seed: int = None) -> tuple:
    """Build the training and evaluation environments chosen by ``cfg.env.type``.

    Each backend's packages are imported inside its own branch, so only the
    selected backend's imports are executed.

    Args:
        cfg: Run configuration. ``cfg.env.type`` selects the backend and
            ``cfg.env.name`` / ``cfg.hyperparameters.num_envs`` are read by
            every backend. The ``"maniskill"`` branch additionally *writes*
            ``cfg.env.max_episode_steps`` and ``cfg.hyperparameters.gamma``.
        device: Torch device. ``"humanoid_bench"`` receives the device object
            itself, ``"isaaclab"`` and ``"maniskill"`` receive ``device.type``,
            and ``"mjx"`` uses ``cfg.platform.device_rank`` instead.
        seed: Forwarded to the ``"isaaclab"`` and ``"mjx"`` backends only; the
            other two branches do not read it.

    Returns:
        A ``(envs, eval_envs)`` pair. For ``"humanoid_bench"`` and
        ``"isaaclab"`` both entries are the same object.

    Raises:
        ValueError: If ``cfg.env.type`` is not one of the four supported
            backend names.
    """
    backend = cfg.env.type

    if backend == "humanoid_bench":
        from reppo_alg.env_utils.torch_wrappers.humanoid_bench_env import (
            HumanoidBenchEnv,
        )

        shared_envs = HumanoidBenchEnv(
            cfg.env.name, cfg.hyperparameters.num_envs, device=device
        )
        # The single instance is handed back for both training and evaluation.
        return shared_envs, shared_envs

    if backend == "isaaclab":
        from reppo_alg.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv

        # NOTE(review): the seed is passed under the keyword ``cfg``. Confirm
        # that IsaacLabEnv really names this parameter ``cfg``; its signature
        # is not visible from this file.
        shared_envs = IsaacLabEnv(
            cfg.env.name,
            device.type,
            cfg.hyperparameters.num_envs,
            cfg=seed,
            action_bounds=cfg.env.action_bounds,
        )
        # The single instance is handed back for both training and evaluation.
        return shared_envs, shared_envs

    if backend == "mjx":
        from reppo_alg.env_utils.torch_wrappers.mujoco_playground_env import make_env

        # TODO: Check if re-using same envs for eval could reduce memory usage
        train_envs, eval_envs = make_env(
            env_name=cfg.env.name,
            seed=seed,
            num_envs=cfg.hyperparameters.num_envs,
            # Evaluation uses the same number of environments as training.
            num_eval_envs=cfg.hyperparameters.num_envs,
            device_rank=cfg.platform.device_rank,
            use_domain_randomization=False,
            use_push_randomization=True,
        )
        return train_envs, eval_envs

    if backend == "maniskill":
        import gymnasium as gym
        import mani_skill.envs  # noqa: F401
        from mani_skill.utils import gym_utils
        from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper
        from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
        from reppo_alg.env_utils.torch_wrappers.maniskill_wrapper import (
            ManiSkillWrapper,
        )

        def _make_raw(reconfiguration_freq):
            # The two raw environments differ only in ``reconfiguration_freq``.
            return gym.make(
                cfg.env.name,
                num_envs=cfg.hyperparameters.num_envs,
                reconfiguration_freq=reconfiguration_freq,
                **cfg.env.env_kwargs,
            )

        train_envs = _make_raw(None)
        eval_envs = _make_raw(1)

        # Side effect: the episode horizon found on the training env is stored
        # back into the config, and gamma is derived from it.
        cfg.env.max_episode_steps = gym_utils.find_max_episode_steps_value(
            train_envs
        )
        # heuristic for setting gamma
        cfg.hyperparameters.gamma = 1.0 - 10.0 / cfg.env.max_episode_steps

        # Only the training env's action space is inspected; when it is a
        # Dict space, both envs are flattened.
        if isinstance(train_envs.action_space, gym.spaces.Dict):
            train_envs = FlattenActionSpaceWrapper(train_envs)
            eval_envs = FlattenActionSpaceWrapper(eval_envs)

        train_envs = ManiSkillVectorEnv(
            train_envs,
            cfg.hyperparameters.num_envs,
            ignore_terminations=not cfg.env.partial_reset,
            record_metrics=True,
        )
        # Evaluation always ignores terminations, whatever ``partial_reset`` is.
        eval_envs = ManiSkillVectorEnv(
            eval_envs,
            cfg.hyperparameters.num_envs,
            ignore_terminations=True,
            record_metrics=True,
        )

        # Both final wrappers take identical settings.
        wrapper_kwargs = dict(
            max_episode_steps=cfg.env.max_episode_steps,
            partial_reset=cfg.env.partial_reset,
            device=device.type,
        )
        return (
            ManiSkillWrapper(train_envs, **wrapper_kwargs),
            ManiSkillWrapper(eval_envs, **wrapper_kwargs),
        )

    raise ValueError(
        f"Unknown environment type: {cfg.env.type}. Supported types are 'humanoid_bench', 'isaaclab', 'maniskill', and 'mjx'."
    )