# dppo/agent/dataset/sequence.py
"""
Pre-training data loader. Modified from https://github.com/jannerm/diffuser/blob/main/diffuser/datasets/sequence.py
No normalization is applied here --- we always normalize the data when pre-processing it with a different script, and the normalization info is also used in RL fine-tuning.
"""
from collections import namedtuple
import numpy as np
import torch
import logging
import pickle
import random
# Module-level logger used to report dataset loading progress
log = logging.getLogger(__name__)
# One training sample: chunked actions plus a dict of conditioning observations
Batch = namedtuple("Batch", "actions conditions")
class StitchedSequenceDataset(torch.utils.data.Dataset):
    """
    Load stitched trajectories of states/actions/images, and a 1-D array of
    traj_lengths, from an npz or pkl file.

    Uses the first max_n_episodes episodes (instead of random sampling).

    Example:
        states: [----------traj 1----------][---------traj 2----------] ... [---------traj N----------]
        Episode IDs (determined based on traj_lengths): [---------- 1 ----------][---------- 2 ---------] ... [---------- N ---------]

    Each sample is a namedtuple Batch of (1) chunked actions and (2) a dict of
    conditioning observations with key "state" (and "rgb" if use_img is set).
    """

    def __init__(
        self,
        dataset_path,
        horizon_steps=64,
        cond_steps=1,
        img_cond_steps=1,
        max_n_episodes=10000,
        use_img=False,
        device="cuda:0",
    ):
        """
        Args:
            dataset_path: path to a .npz (np arrays only) or .pkl file with
                keys "traj_lengths", "states", "actions" (and "images" when
                use_img is True).
            horizon_steps: number of action steps chunked into each sample.
            cond_steps: number of past state observations conditioned on.
            img_cond_steps: number of past image observations conditioned on;
                must not exceed cond_steps.
            max_n_episodes: only the first max_n_episodes episodes are used.
            use_img: whether to also load image observations.
            device: device the state/action (and image) tensors are moved to.
        """
        assert (
            img_cond_steps <= cond_steps
        ), "consider using more cond_steps than img_cond_steps"
        self.horizon_steps = horizon_steps
        self.cond_steps = cond_steps  # states (proprio, etc.)
        self.img_cond_steps = img_cond_steps
        self.device = device
        self.use_img = use_img

        # Load dataset to device specified
        if dataset_path.endswith(".npz"):
            dataset = np.load(dataset_path, allow_pickle=False)  # only np arrays
        elif dataset_path.endswith(".pkl"):
            with open(dataset_path, "rb") as f:
                dataset = pickle.load(f)
        else:
            raise ValueError(f"Unsupported file format: {dataset_path}")
        traj_lengths = dataset["traj_lengths"][:max_n_episodes]  # 1-D array
        total_num_steps = np.sum(traj_lengths)

        # Set up indices for sampling
        self.indices = self.make_indices(traj_lengths, horizon_steps)

        # Extract states and actions up to max_n_episodes
        self.states = (
            torch.from_numpy(dataset["states"][:total_num_steps]).float().to(device)
        )  # (total_num_steps, obs_dim)
        self.actions = (
            torch.from_numpy(dataset["actions"][:total_num_steps]).float().to(device)
        )  # (total_num_steps, action_dim)
        log.info(f"Loaded dataset from {dataset_path}")
        log.info(f"Number of episodes: {min(max_n_episodes, len(traj_lengths))}")
        log.info(f"States shape/type: {self.states.shape, self.states.dtype}")
        log.info(f"Actions shape/type: {self.actions.shape, self.actions.dtype}")
        if self.use_img:
            # NOTE(review): image dtype is kept as stored (not cast to float)
            self.images = torch.from_numpy(dataset["images"][:total_num_steps]).to(
                device
            )  # (total_num_steps, C, H, W)
            log.info(f"Images shape/type: {self.images.shape, self.images.dtype}")

    def __getitem__(self, idx):
        """
        Return one Batch; repeat the earliest state/image when the history
        window would extend past the beginning of the episode.
        """
        start, num_before_start = self.indices[idx]
        end = start + self.horizon_steps
        states = self.states[(start - num_before_start) : end]
        actions = self.actions[start:end]
        # Stack the cond_steps most recent observations, clamping at the
        # episode start so the first observation is repeated as needed.
        states = torch.stack(
            [
                states[max(num_before_start - t, 0)]
                for t in reversed(range(self.cond_steps))
            ]
        )  # more recent is at the end
        conditions = {"state": states}
        if self.use_img:
            images = self.images[(start - num_before_start) : end]
            images = torch.stack(
                [
                    images[max(num_before_start - t, 0)]
                    for t in reversed(range(self.img_cond_steps))
                ]
            )
            conditions["rgb"] = images
        batch = Batch(actions, conditions)
        return batch

    def make_indices(self, traj_lengths, horizon_steps):
        """
        Make indices for sampling from the dataset.

        Each index is a tuple (start step, number of steps before it within
        the same trajectory), chosen so that a full horizon_steps chunk fits
        inside the trajectory.
        """
        indices = []
        cur_traj_index = 0
        for traj_length in traj_lengths:
            # Last valid start keeps start + horizon_steps inside the trajectory
            max_start = cur_traj_index + traj_length - horizon_steps + 1
            indices += [
                (i, i - cur_traj_index) for i in range(cur_traj_index, max_start)
            ]
            cur_traj_index += traj_length
        return indices

    def set_train_val_split(self, train_split):
        """
        Keep a random train_split fraction of the sample indices for training
        and return the positions (into the original index list) of the
        held-out validation samples.

        Not doing validation right now.

        NOTE: the previous implementation compared integer positions against
        sampled (start, num_before_start) tuples, so every position — trained
        or not — was returned as validation; positions are sampled instead.
        """
        num_train = int(len(self.indices) * train_split)
        # Sample positions (not the index tuples) so membership tests compare
        # like with like; a set gives O(1) lookups instead of O(n) per test.
        train_positions = set(random.sample(range(len(self.indices)), num_train))
        val_indices = [
            i for i in range(len(self.indices)) if i not in train_positions
        ]
        self.indices = [self.indices[i] for i in sorted(train_positions)]
        return val_indices

    def __len__(self):
        return len(self.indices)