reppo/reppo_alg/env_utils/torch_wrappers/maniskill_wrapper.py
2025-07-21 18:31:20 -04:00

90 lines
3.2 KiB
Python

from gymnasium import Wrapper
import torch
class ManiSkillWrapper(Wrapper):
    """
    A wrapper for ManiSkill environments to ensure compatibility with the expected API.

    This wrapper is used to handle the ManiSkill environments in a way that is consistent
    with the other environments in the codebase. It:

    * tracks, per sub-environment, the return, episode length and ``success_once``
      flag of the most recently finished episode and exposes them as
      ``info["log_info"]`` on every step;
    * remaps ManiSkill's ``terminated``/``truncated`` flags into the ``done``/
      ``truncated`` convention used elsewhere (see :meth:`step`).
    """

    def __init__(self, env, max_episode_steps: int, partial_reset: bool, device: str):
        """
        Args:
            env: The (vectorized) ManiSkill environment to wrap; must expose
                ``num_envs``, ``action_space``, ``observation_space`` and ``metadata``.
            max_episode_steps: Stored on the wrapper as ``self.max_episode_steps``.
            partial_reset: Selects how episode ends are reported by :meth:`step`.
            device: Torch device on which the per-env episode statistics are kept.
        """
        super().__init__(env)
        self.action_space = env.action_space
        self.observation_space = env.observation_space
        self.metadata = env.metadata
        self.asymmetric_obs = False
        self.max_episode_steps = max_episode_steps
        self.partial_reset = partial_reset
        # Statistics of the last *completed* episode of each sub-env; entries are
        # only overwritten when that sub-env finishes an episode.
        self.returns = torch.zeros(env.num_envs, dtype=torch.float32, device=device)
        self.episode_len = torch.zeros(env.num_envs, dtype=torch.float32, device=device)
        self.success = torch.zeros(env.num_envs, dtype=torch.float32, device=device)

    @property
    def unwrapped(self):
        """
        Returns the underlying environment.

        Note: this is the directly wrapped ``env`` passed to ``__init__``, not
        gymnasium's innermost base environment.
        """
        return self.env

    @property
    def num_actions(self):
        """
        Returns the number of actions in the action space.

        Reads index 1 of the action-space shape; presumably index 0 is the
        batch (num_envs) dimension -- TODO confirm for all ManiSkill envs.
        """
        return self.action_space.shape[1]

    @property
    def num_obs(self):
        """
        Returns the number of observations in the observation space.

        Reads index 1 of the observation-space shape; presumably index 0 is the
        batch (num_envs) dimension -- TODO confirm for all ManiSkill envs.
        """
        return self.observation_space.shape[1]

    def reset(self, seed=None, options=None):
        """
        Resets the environment and returns whatever the wrapped env's reset returns.

        Args:
            seed: Forwarded unchanged to the wrapped env.
            options: Reset options forwarded to the wrapped env. ``None`` (the
                default) is replaced by a fresh empty dict on every call, so no
                dict object is shared between calls.
        """
        # A fresh dict per call: a shared mutable default would leak any
        # mutation made by the env into every later reset.
        options = {} if options is None else options
        return self.env.reset(seed=seed, options=options)

    @staticmethod
    def _latch(done_mask, new_values, old_values):
        """Return ``new_values`` where ``done_mask`` is set and ``old_values`` elsewhere."""
        # torch.where selects instead of blending arithmetically, so a NaN/inf in
        # the branch that is not selected cannot leak into the result
        # (0 * nan == nan). Values are cast to the tracker's dtype/device so the
        # trackers keep their float32 dtype (e.g. for bool ``success_once``).
        new_values = new_values.to(device=old_values.device, dtype=old_values.dtype)
        done_mask = done_mask.to(old_values.device).bool()
        return torch.where(done_mask, new_values, old_values)

    def step(self, action):
        """
        Takes a step in the environment with the given action.

        Returns:
            A 5-tuple ``(obs, reward, done, truncated, info)``.

            * If ``partial_reset`` is set, ``done`` is all-False and ``truncated``
              is ``terminated OR truncated`` from the wrapped env.
            * Otherwise ``done`` is ``terminated OR truncated`` and ``truncated``
              is all-False.

            ``info["log_info"]`` holds the per-env ``return``, ``episode_len`` and
            ``success`` of each sub-env's most recently finished episode.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        if "final_info" in info:
            # "_final_info" marks which sub-envs just finished an episode; only
            # their statistics are replaced.
            done_mask = info["_final_info"]
            episode = info["final_info"]["episode"]
            self.returns = self._latch(done_mask, episode["return"], self.returns)
            self.episode_len = self._latch(
                done_mask, episode["episode_len"], self.episode_len
            )
            self.success = self._latch(
                done_mask, episode["success_once"], self.success
            )
        info["log_info"] = {
            "return": self.returns,
            "episode_len": self.episode_len,
            "success": self.success,
        }
        if self.partial_reset:
            # maniskill continues bootstrap on terminated, which playground does on truncated.
            # This unifies the interfaces in a very hacky way: every episode end is
            # reported through `truncated`, and `done` is never set.
            done = torch.zeros_like(terminated, dtype=torch.bool)
            truncated = torch.logical_or(terminated, truncated)
        else:
            done = torch.logical_or(terminated, truncated)
            truncated = torch.zeros_like(done, dtype=torch.bool)
        return obs, reward, done, truncated, info