from gymnasium import Wrapper import torch class ManiSkillWrapper(Wrapper): """ A wrapper for ManiSkill environments to ensure compatibility with the expected API. This wrapper is used to handle the ManiSkill environments in a way that is consistent with the other environments in the codebase. """ def __init__(self, env, max_episode_steps: int, partial_reset, device: str): super().__init__(env) self.action_space = env.action_space self.observation_space = env.observation_space self.metadata = env.metadata self.asymmetric_obs = False self.max_episode_steps = max_episode_steps self.partial_reset = partial_reset self.returns = torch.zeros(env.num_envs, dtype=torch.float32, device=device) self.episode_len = torch.zeros(env.num_envs, dtype=torch.float32, device=device) self.success = torch.zeros(env.num_envs, dtype=torch.float32, device=device) @property def unwrapped(self): """ Returns the underlying environment. """ return self.env @property def num_actions(self): """ Returns the number of actions in the action space. """ return self.action_space.shape[1] @property def num_obs(self): """ Returns the number of observations in the observation space. """ return self.observation_space.shape[1] def reset(self, seed=None, options=dict()): """ Resets the environment and returns the initial observation. """ return self.env.reset(seed=seed, options=options) def step(self, action): """ Takes a step in the environment with the given action. Returns the next observation, reward, done, and info. """ obs, reward, terminated, truncated, info = self.env.step(action) if "final_info" in info: self.returns = ( info["final_info"]["episode"]["return"] * info["_final_info"].float() + (1.0 - info["_final_info"].float()) * self.returns ) self.episode_len = ( info["final_info"]["episode"]["episode_len"] * info["_final_info"].float() + (1.0 - info["_final_info"].float()) * self.episode_len ) self.success = ( info["final_info"]["episode"]["success_once"] * info["_final_info"].float() + (1.0 - info["_final_info"].float()) * self.success ) info["log_info"] = { "return": self.returns, "episode_len": self.episode_len, "success": self.success, } if self.partial_reset: # maniskill continues bootstrap on terminated, which playground does on truncated. # This unifies the interfaces in a very hacky way done = torch.zeros_like( terminated, dtype=torch.bool, device=terminated.device ) truncated = torch.logical_or(terminated, truncated) else: done = torch.logical_or(terminated, truncated) truncated = torch.zeros_like(done, dtype=torch.bool, device=done.device) return obs, reward, done, truncated, info