# dppo/env/gym_utils/wrapper/furniture.py
# 2024-09-11 21:09:17 -04:00
# 151 lines, 5.3 KiB, Python
"""
Environment wrapper for Furniture-Bench environments.
"""
import gym
import numpy as np
from furniture_bench.envs.furniture_rl_sim_env import FurnitureRLSimEnv
import torch
from furniture_bench.controllers.control_utils import proprioceptive_quat_to_6d_rotation
from ..furniture_normalizer import LinearNormalizer
from .multi_step import repeated_space
import logging
log = logging.getLogger(__name__)
class FurnitureRLSimEnvMultiStepWrapper(gym.Wrapper):
    """Action-chunk wrapper for FurnitureRLSimEnv.

    Accepts a chunk of ``n_action_steps`` actions per ``step()`` call, steps
    the wrapped (vectorized, GPU-resident) environment once per action, and
    returns a single aggregated observation/reward/done per chunk.
    Observations and actions are (de)normalized with a ``LinearNormalizer``
    loaded from ``normalization_path``.

    NOTE: attributes not defined here (e.g. ``self.device``,
    ``self.max_env_steps``, ``self.num_envs``) are resolved on the wrapped
    env via ``gym.Wrapper.__getattr__`` delegation — TODO confirm the
    wrapped env exposes all of them.
    """

    env: FurnitureRLSimEnv

    def __init__(
        self,
        env: FurnitureRLSimEnv,
        n_obs_steps: int = 1,
        n_action_steps: int = 1,
        max_episode_steps=None,
        sparse_reward: bool = False,
        reward_agg_method: str = "sum",  # never use other types
        reset_within_step: bool = False,
        pass_full_observations: bool = False,
        normalization_path=None,
        prev_action: bool = False,
    ):
        # These options belong to the generic multi-step interface but are
        # unsupported for furniture envs — fail fast on misuse.
        assert (
            not reset_within_step
        ), "reset_within_step must be False for furniture envs"
        assert n_obs_steps == 1, "n_obs_steps must be 1"
        assert reward_agg_method == "sum", "reward_agg_method must be sum"
        assert (
            not pass_full_observations
        ), "pass_full_observations is not implemented yet"
        assert not prev_action, "prev_action is not implemented yet"
        super().__init__(env)
        self._single_action_space = env.action_space
        # Repeat the per-step spaces along a leading "steps" dimension.
        self._action_space = repeated_space(env.action_space, n_action_steps)
        self._observation_space = repeated_space(env.observation_space, n_obs_steps)
        # Stored but not read in this class; truncation in step() uses the
        # wrapped env's max_env_steps instead — TODO confirm intended.
        self.max_episode_steps = max_episode_steps
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        self.pass_full_observations = pass_full_observations
        # Use the original reward function where the robot does not receive
        # new reward after completing one part
        self.sparse_reward = sparse_reward
        # Set up normalization.
        # BUGFIX: only load the state dict when a path is actually given —
        # previously torch.load(None) raised at init with the declared
        # default normalization_path=None.
        self.normalize = normalization_path is not None
        self.normalizer = LinearNormalizer()
        if self.normalize:
            self.normalizer.load_state_dict(
                torch.load(
                    normalization_path, map_location=self.device, weights_only=True
                )
            )
            log.info(f"Loaded normalization from {normalization_path}")

    def reset(
        self,
        **kwargs,
    ):
        """Resets the environment.

        Returns:
            dict with key "state": normalized observations of shape
            (n_envs, n_obs_steps, obs_dim) as a numpy array.
        """
        obs = self.env.reset()
        nobs = self.process_obs(obs)
        # Per-env running best (cumulative sparse) reward, used by
        # _inner_step to assign "permanent" dense rewards.
        self.best_reward = torch.zeros(self.env.num_envs).to(self.device)
        self.done = list()
        return {"state": nobs}

    def reset_arg(self, options_list=None):
        """Compatibility shim: ignores options and performs a full reset."""
        return self.reset()

    def reset_one_arg(self, env_ind=None, options=None):
        """Compatibility shim: per-env reset is not supported; resets all envs.

        NOTE(review): env_ind is converted to a tensor but never used —
        the full reset below ignores it. TODO confirm intended.
        """
        if env_ind is not None:
            env_ind = torch.tensor([env_ind], device=self.device)
        return self.reset()

    def step(self, action: np.ndarray):
        """
        Takes in a chunk of actions of length n_action_steps
        and steps the environment n_action_steps times
        and returns an aggregated observation, reward, and done signal
        """
        # action: (n_envs, n_action_steps, action_dim)
        action = torch.tensor(action, device=self.device)
        # Denormalize the action (only if a normalizer was loaded)
        if self.normalize:
            action = self.normalizer(action, "actions", forward=False)
        # Step the environment n_action_steps times
        obs, sparse_reward, dense_reward, done, info = self._inner_step(action)
        if self.sparse_reward:
            reward = sparse_reward.clone().cpu().numpy()
        else:
            reward = dense_reward.clone().cpu().numpy()
        # Only mark the environment as done if it times out, ignore done from
        # inner steps. max_env_steps comes from the wrapped env via
        # gym.Wrapper delegation — TODO confirm it exists on all envs used.
        truncated = self.env.env_steps >= self.max_env_steps
        done = truncated
        nobs: np.ndarray = self.process_obs(obs)
        done: np.ndarray = done.squeeze().cpu().numpy()
        return {"state": nobs}, reward, done, info

    def _inner_step(self, action_chunk: torch.Tensor):
        """Steps the env once per action in the chunk; aggregates rewards.

        Returns:
            (obs, sparse_reward, dense_reward, dones, info) where rewards are
            per-env sums over the chunk and dones is the OR over inner steps.
        """
        dones = torch.zeros(
            action_chunk.shape[0], dtype=torch.bool, device=action_chunk.device
        )
        dense_reward = torch.zeros(action_chunk.shape[0], device=action_chunk.device)
        sparse_reward = torch.zeros(action_chunk.shape[0], device=action_chunk.device)
        for i in range(self.n_action_steps):
            # The dimensions of the action_chunk are (num_envs, chunk_size, action_dim)
            obs, reward, done, info = self.env.step(action_chunk[:, i, :])
            # track raw reward
            sparse_reward += reward.squeeze()
            # track best reward --- reward nonzero only one part is assembled
            self.best_reward += reward.squeeze()
            # assign "permanent" rewards: once a part is assembled, every
            # subsequent inner step keeps crediting the accumulated best
            dense_reward += self.best_reward
            dones = dones | done.squeeze()
        return obs, sparse_reward, dense_reward, dones, info

    def process_obs(self, obs: torch.Tensor) -> np.ndarray:
        """Converts raw env observations into a normalized flat state array.

        Returns:
            numpy array of shape (n_envs, n_obs_steps, obs_dim).
        """
        robot_state = obs["robot_state"]
        # Convert the robot state to have 6D pose
        robot_state = proprioceptive_quat_to_6d_rotation(robot_state)
        parts_poses = obs["parts_poses"]
        obs = torch.cat([robot_state, parts_poses], dim=-1)
        if self.normalize:
            nobs = self.normalizer(obs, "observations", forward=True)
            # clip normalized observations to suppress outliers
            nobs = torch.clamp(nobs, -5, 5)
        else:
            nobs = obs
        # Insert a dummy dimension for the n_obs_steps (n_envs, obs_dim) -> (n_envs, n_obs_steps, obs_dim)
        nobs = nobs.unsqueeze(1).cpu().numpy()
        return nobs