reppo/reppo_alg/env_utils/torch_wrappers/isaaclab_env.py
2025-07-21 18:31:20 -04:00

82 lines
3.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Optional
import gymnasium as gym
import torch
from isaaclab.app import AppLauncher
from isaaclab_tasks.utils.parse_cfg import parse_env_cfg
# Create the IsaacLab app launcher in headless mode. This runs at import time,
# so importing this module has the side effect of constructing the launcher.
# NOTE(review): IsaacLab scripts usually create the AppLauncher before importing
# modules that depend on omni/isaacsim extensions; confirm that importing
# `parse_env_cfg` above (before this line) is safe in the deployed setup.
app_launcher = AppLauncher(headless=True)
# Module-level handle to the launched application; it is not closed anywhere
# in this module.
simulation_app = app_launcher.app
class IsaacLabEnv:
    """Wrapper for IsaacLab environments to be compatible with MuJoCo Playground.

    Builds an IsaacLab task with ``gym.make`` and exposes a batched
    reset/step interface that returns torch tensors. Policy observations come
    from the ``"policy"`` entry of IsaacLab's observation dict. When the task
    also provides a ``"critic"`` entry, it is exposed as privileged
    (asymmetric) observations.

    Attributes:
        seed: Seed written into the environment config.
        envs: Environment object returned by ``gym.make``.
        num_envs: Number of parallel environments reported by the unwrapped env.
        max_episode_steps: ``max_episode_length`` of the unwrapped env.
        action_bounds: If not ``None``, ``step`` clamps actions to [-1, 1] and
            multiplies them by this value.
        num_obs: Size of one environment's ``"policy"`` observation.
        asymmetric_obs: Whether the observation space has a ``"critic"`` entry.
        num_privileged_obs: Size of one ``"critic"`` observation, or 0.
        num_actions: Size of one environment's action.
    """

    def __init__(
        self,
        task_name: str,
        device: str,
        num_envs: int,
        seed: int,
        action_bounds: Optional[float] = None,
    ):
        """Create the IsaacLab environment and cache its dimensions.

        Args:
            task_name: Registered gym id of the IsaacLab task.
            device: Device string forwarded to ``parse_env_cfg``.
            num_envs: Requested number of parallel environments.
            seed: Seed stored on the env config and on this wrapper.
            action_bounds: Optional scale applied to actions after clamping
                them to [-1, 1]; ``None`` passes actions through unchanged.
        """
        env_cfg = parse_env_cfg(
            task_name,
            device=device,
            num_envs=num_envs,
        )
        env_cfg.seed = seed
        self.seed = seed
        self.envs = gym.make(task_name, cfg=env_cfg, render_mode=None)
        unwrapped = self.envs.unwrapped
        # Sizes are read back from the unwrapped env rather than the arguments.
        self.num_envs = unwrapped.num_envs
        self.max_episode_steps = unwrapped.max_episode_length
        self.action_bounds = action_bounds
        obs_space = unwrapped.single_observation_space
        self.num_obs = obs_space["policy"].shape[0]
        self.asymmetric_obs = "critic" in obs_space
        if self.asymmetric_obs:
            self.num_privileged_obs = obs_space["critic"].shape[0]
        else:
            self.num_privileged_obs = 0
        self.num_actions = unwrapped.single_action_space.shape[0]

    def _randomize_episode_lengths(self) -> None:
        """Replace each env's episode-length counter with a random value.

        Values are drawn uniformly from ``[0, max_episode_steps)`` so episode
        horizons are decorrelated across envs (as RSL-RL does).
        """
        unwrapped = self.envs.unwrapped
        unwrapped.episode_length_buf = torch.randint_like(
            unwrapped.episode_length_buf, high=int(self.max_episode_steps)
        )

    def reset(self, random_start_init: bool = True) -> torch.Tensor:
        """Reset all environments and return the policy observations.

        Args:
            random_start_init: If True, randomize the per-env episode-length
                counters after the reset to decorrelate episode horizons.

        Returns:
            The ``"policy"`` observation tensor produced by the reset.
        """
        obs_dict, _ = self.envs.reset()
        # NOTE: decorrelate episode horizons like RSLRL
        if random_start_init:
            self._randomize_episode_lengths()
        return obs_dict["policy"]

    def reset_with_critic_obs(
        self, random_start_init: bool = False
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Reset all environments and return policy and critic observations.

        Args:
            random_start_init: If True, randomize the per-env episode-length
                counters after the reset, exactly as ``reset`` does. Defaults
                to False to keep this method's historical behavior.

        Returns:
            ``(policy_obs, critic_obs)``. ``critic_obs`` is ``None`` when the
            reset observations have no ``"critic"`` entry, matching what
            ``step`` reports for non-asymmetric tasks.
        """
        obs_dict, _ = self.envs.reset()
        if random_start_init:
            self._randomize_episode_lengths()
        # .get avoids a KeyError for tasks without privileged observations.
        return obs_dict["policy"], obs_dict.get("critic")

    def step(
        self, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
        """Step all environments with a batch of actions.

        Args:
            actions: Action tensor passed to the env. If ``action_bounds`` is
                set, it is first clamped to [-1, 1] and scaled by it.

        Returns:
            ``(obs, rew, dones, info)`` where ``obs`` is the ``"policy"``
            observation, ``rew`` is the reward as returned by the env,
            ``dones`` is ``terminations | truncations`` cast to ``torch.long``,
            and ``info`` holds ``"time_outs"`` (the truncation flags) and
            ``"observations"`` with ``"critic"`` (``None`` unless asymmetric)
            and ``"raw"`` copies of both observations.
        """
        if self.action_bounds is not None:
            actions = torch.clamp(actions, -1.0, 1.0) * self.action_bounds
        obs_dict, rew, terminations, truncations, _ = self.envs.step(actions)
        dones = (terminations | truncations).to(dtype=torch.long)
        obs = obs_dict["policy"]
        critic_obs = obs_dict["critic"] if self.asymmetric_obs else None
        info_ret = {"time_outs": truncations, "observations": {"critic": critic_obs}}
        # NOTE: There's really no way to get the raw observations from IsaacLab
        # We just use the 'reset_obs' as next_obs, unfortunately.
        # See https://github.com/isaac-sim/IsaacLab/issues/1362
        info_ret["observations"]["raw"] = {
            "obs": obs,
            "critic_obs": critic_obs,
        }
        return obs, rew, dones, info_ret

    def close(self) -> None:
        """Close the wrapped gym environment.

        Only calls ``self.envs.close()``; the module-level ``simulation_app``
        is not closed here.
        """
        self.envs.close()

    def render(self):
        """Rendering is not supported; always raises NotImplementedError."""
        raise NotImplementedError(
            "We don't support rendering for IsaacLab environments"
        )