96 lines
3.4 KiB
Python
96 lines
3.4 KiB
Python
from typing import Optional

import torch
from omegaconf import DictConfig
|
|
|
|
|
|
def make_envs(cfg: DictConfig, device: torch.device, seed: Optional[int] = None) -> tuple:
    """Build the training and evaluation environments selected by ``cfg.env.type``.

    Backend packages are imported lazily inside the matching branch, so only
    the simulator that is actually requested has to be installed.

    Supported values of ``cfg.env.type``:

    * ``"humanoid_bench"`` -- one ``HumanoidBenchEnv`` is created and returned
      for both roles. ``seed`` is not forwarded on this path.
    * ``"isaaclab"`` -- one ``IsaacLabEnv`` is created and returned for both
      roles.
    * ``"mjx"`` -- ``make_env`` returns distinct training and evaluation
      environments, each with ``cfg.hyperparameters.num_envs`` instances.
    * ``"maniskill"`` -- two separate environments are created and wrapped.
      This branch also **writes to ``cfg`` in place**: it sets
      ``cfg.env.max_episode_steps`` from the training environment and
      overwrites ``cfg.hyperparameters.gamma`` with
      ``1 - 10 / max_episode_steps``.

    Args:
        cfg: Run configuration. Reads ``cfg.env.type``, ``cfg.env.name`` and
            ``cfg.hyperparameters.num_envs``, plus the branch-specific keys
            ``cfg.env.action_bounds`` (isaaclab), ``cfg.platform.device_rank``
            (mjx), ``cfg.env.env_kwargs`` and ``cfg.env.partial_reset``
            (maniskill).
        device: Torch device handed to the environment wrappers. The isaaclab
            and maniskill paths pass only ``device.type``.
        seed: Optional seed; forwarded on the isaaclab and mjx paths only.

    Returns:
        A ``(train_envs, eval_envs)`` pair. For ``humanoid_bench`` and
        ``isaaclab`` both entries are the same object.

    Raises:
        ValueError: If ``cfg.env.type`` is not one of the supported types.
    """
    env_type = cfg.env.type

    if env_type == "humanoid_bench":
        from reppo_alg.env_utils.torch_wrappers.humanoid_bench_env import (
            HumanoidBenchEnv,
        )

        envs = HumanoidBenchEnv(
            cfg.env.name, cfg.hyperparameters.num_envs, device=device
        )
        # The same instance serves as both training and evaluation env.
        return envs, envs

    if env_type == "isaaclab":
        from reppo_alg.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv

        envs = IsaacLabEnv(
            cfg.env.name,
            device.type,
            cfg.hyperparameters.num_envs,
            # NOTE(review): the seed is passed through a keyword named `cfg`.
            # Left unchanged because IsaacLabEnv's signature is not visible
            # here -- confirm this is not meant to be `seed=seed`.
            cfg=seed,
            action_bounds=cfg.env.action_bounds,
        )
        # The same instance serves as both training and evaluation env.
        return envs, envs

    if env_type == "mjx":
        from reppo_alg.env_utils.torch_wrappers.mujoco_playground_env import make_env

        # TODO: Check if re-using same envs for eval could reduce memory usage
        envs, eval_envs = make_env(
            env_name=cfg.env.name,
            seed=seed,
            num_envs=cfg.hyperparameters.num_envs,
            num_eval_envs=cfg.hyperparameters.num_envs,
            device_rank=cfg.platform.device_rank,
            use_domain_randomization=False,
            use_push_randomization=True,
        )
        return envs, eval_envs

    if env_type == "maniskill":
        import gymnasium as gym
        import mani_skill.envs  # noqa: F401
        from mani_skill.utils import gym_utils
        from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper
        from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
        from reppo_alg.env_utils.torch_wrappers.maniskill_wrapper import (
            ManiSkillWrapper,
        )

        num_envs = cfg.hyperparameters.num_envs

        # Training and evaluation envs differ only in `reconfiguration_freq`.
        envs = gym.make(
            cfg.env.name,
            num_envs=num_envs,
            reconfiguration_freq=None,
            **cfg.env.env_kwargs,
        )
        eval_envs = gym.make(
            cfg.env.name,
            num_envs=num_envs,
            reconfiguration_freq=1,
            **cfg.env.env_kwargs,
        )

        # Side effect: the episode length found on the training env is stored
        # back into the config, and gamma is derived from it.
        cfg.env.max_episode_steps = gym_utils.find_max_episode_steps_value(envs)
        # heuristic for setting gamma
        cfg.hyperparameters.gamma = 1.0 - 10.0 / cfg.env.max_episode_steps

        # The decision to flatten is made from the training env's action space
        # and applied to both envs.
        if isinstance(envs.action_space, gym.spaces.Dict):
            envs = FlattenActionSpaceWrapper(envs)
            eval_envs = FlattenActionSpaceWrapper(eval_envs)

        envs = ManiSkillVectorEnv(
            envs,
            num_envs,
            ignore_terminations=not cfg.env.partial_reset,
            record_metrics=True,
        )
        # Evaluation always ignores terminations, whatever `partial_reset` is.
        eval_envs = ManiSkillVectorEnv(
            eval_envs,
            num_envs,
            ignore_terminations=True,
            record_metrics=True,
        )

        wrapper_kwargs = dict(
            max_episode_steps=cfg.env.max_episode_steps,
            partial_reset=cfg.env.partial_reset,
            device=device.type,
        )
        return (
            ManiSkillWrapper(envs, **wrapper_kwargs),
            ManiSkillWrapper(eval_envs, **wrapper_kwargs),
        )

    raise ValueError(
        f"Unknown environment type: {cfg.env.type}. Supported types are 'humanoid_bench', 'isaaclab', 'maniskill', and 'mjx'."
    )
|