115 lines
4.3 KiB
Python
115 lines
4.3 KiB
Python
"""
|
|
Pre-training data loader. Modified from https://github.com/jannerm/diffuser/blob/main/diffuser/datasets/sequence.py
|
|
|
|
TODO: implement history observation
|
|
|
|
No normalization is applied here --- we always normalize the data when pre-processing it with a different script, and the normalization info is also used in RL fine-tuning.
|
|
|
|
"""
|
|
|
|
from collections import namedtuple
|
|
import numpy as np
|
|
import torch
|
|
import logging
|
|
import pickle
|
|
import random
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)

# One dataset sample as produced by StitchedSequenceDataset.__getitem__:
#   trajectories -- the window of actions
#   conditions   -- dict of conditioning observations for that window
Batch = namedtuple(typename="Batch", field_names="trajectories conditions")
|
|
|
|
|
|
class StitchedSequenceDataset(torch.utils.data.Dataset):
    """
    Dataset to efficiently load and sample fixed-length trajectory windows.

    Episodes are stitched together in the time dimension to avoid excessive
    zero padding. The stored arrays (`states`, `actions`, and optionally
    `images`) are indexed along their first axis by the stitched time step,
    i.e. that axis has length sum_e(T_e) where T_e is the length of episode e.

    Example:
        states:      [----------traj 1----------][---------traj 2----------] ... [---------traj N----------]
        Episode IDs: [----------   1  ----------][----------   2  ---------] ... [----------   N  ---------]

    `self.indices` holds every start offset (into the stitched arrays) from
    which a window of `horizon_steps` steps fits inside a single episode.
    Each item is a `Batch(trajectories, conditions)` namedtuple.
    """

    def __init__(
        self,
        dataset_path,
        horizon_steps=64,
        cond_steps=1,
        max_n_episodes=10000,
        use_img=False,
        device="cuda:0",
    ):
        """
        Load the dataset file and move the selected episodes to `device`.

        Args:
            dataset_path: path to a `.npz` archive, or otherwise a pickle file,
                providing "traj_lengths", "states", "actions" (and "images"
                when `use_img` is set).
            horizon_steps: number of consecutive steps in each sampled window.
            cond_steps: number of conditioning steps; only used to compute the
                key of the conditions dict (`1 - cond_steps`).
            max_n_episodes: only the first `max_n_episodes` episodes are used.
            use_img: also load "images" and add them to the conditions.
            device: torch device the tensors are stored on.
        """
        super().__init__()
        self.horizon_steps = horizon_steps
        self.cond_steps = cond_steps
        self.device = device
        self.use_img = use_img

        # Load dataset to device specified
        if dataset_path.endswith(".npz"):
            dataset = np.load(dataset_path, allow_pickle=False)  # only np arrays
        else:
            # NOTE(security): unpickling can execute arbitrary code; only load
            # dataset files from trusted sources.
            with open(dataset_path, "rb") as f:
                dataset = pickle.load(f)

        # Restrict to the first max_n_episodes episodes *before* building the
        # sampling indices, so that every index refers to data actually loaded.
        traj_lengths = dataset["traj_lengths"][:max_n_episodes]  # 1-D array
        total_num_steps = int(np.sum(traj_lengths))

        # Set up indices for sampling
        self.indices = self.make_indices(traj_lengths, horizon_steps)

        # Extract states and actions up to max_n_episodes
        self.states = (
            torch.from_numpy(dataset["states"][:total_num_steps]).float().to(device)
        )  # (total_num_steps, obs_dim)
        self.actions = (
            torch.from_numpy(dataset["actions"][:total_num_steps]).float().to(device)
        )  # (total_num_steps, action_dim)
        log.info(f"Loaded dataset from {dataset_path}")
        log.info(f"Number of episodes: {len(traj_lengths)}")
        log.info(f"States shape/type: {self.states.shape, self.states.dtype}")
        log.info(f"Actions shape/type: {self.actions.shape, self.actions.dtype}")
        if self.use_img:
            self.images = torch.from_numpy(dataset["images"][:total_num_steps]).to(
                device
            )  # (total_num_steps, C, H, W)
            log.info(f"Images shape/type: {self.images.shape, self.images.dtype}")

    def __getitem__(self, idx):
        """
        Return the `idx`-th window as `Batch(actions, conditions)`.

        `actions` is the slice of `horizon_steps` actions starting at the
        window's start offset. `conditions` maps the key `1 - cond_steps` to
        the first state of the window (or to a dict with "state" and "rgb"
        entries when images are used).
        """
        start = self.indices[idx]
        end = start + self.horizon_steps
        states = self.states[start:end]
        actions = self.actions[start:end]
        if self.use_img:
            images = self.images[start:end]
            conditions = {
                1 - self.cond_steps: {"state": states[0], "rgb": images[0]}
            }  # TODO: allow obs history, -1, -2, ...
        else:
            conditions = {1 - self.cond_steps: states[0]}
        return Batch(actions, conditions)

    def make_indices(self, traj_lengths, horizon_steps):
        """
        Make indices for sampling from the dataset; each index maps to a datapoint.

        Args:
            traj_lengths: iterable of per-episode lengths, in stitched order.
            horizon_steps: window length in steps.

        Returns:
            List of start offsets into the stitched arrays such that a window
            of `horizon_steps` steps stays within one episode. Episodes shorter
            than `horizon_steps` contribute no index.
        """
        indices = []
        cur_traj_index = 0
        for traj_length in traj_lengths:
            traj_length = int(traj_length)  # plain ints even for numpy input
            # Last valid start is the one whose window ends at the episode end.
            max_start = cur_traj_index + traj_length - horizon_steps + 1
            indices.extend(range(cur_traj_index, max_start))
            cur_traj_index += traj_length
        return indices

    def set_train_val_split(self, train_split):
        """
        Randomly keep a fraction of the sampling indices for training.

        Args:
            train_split: fraction (0..1) of the current indices to keep.

        Returns:
            The held-out start indices (the entries of the previous
            `self.indices` that were not selected for training), in their
            original order. `self.indices` is replaced by the training subset.
        """
        num_train = int(len(self.indices) * train_split)
        train_indices = random.sample(self.indices, num_train)
        # Set membership keeps the filter linear in the number of indices.
        train_set = set(train_indices)
        val_indices = [i for i in self.indices if i not in train_set]
        self.indices = train_indices
        return val_indices

    def __len__(self):
        """Return the number of sampling windows currently available."""
        return len(self.indices)
|