dppo/agent/dataset/sequence.py
2024-09-03 21:03:27 -04:00

163 lines
6.0 KiB
Python

"""
Pre-training data loader. Modified from https://github.com/jannerm/diffuser/blob/main/diffuser/datasets/sequence.py
TODO: implement history observation
No normalization is applied here --- we always normalize the data when pre-processing it with a different script, and the normalization info is also used in RL fine-tuning.
"""
from collections import namedtuple
from tqdm import tqdm
import numpy as np
import torch
import logging
import pickle
import random
log = logging.getLogger(__name__)
from .buffer import StitchedBuffer
# A sampled training batch: `trajectories` holds the target sequence and
# `conditions` maps conditioning timesteps to observations (see get_conditions).
Batch = namedtuple("Batch", "trajectories conditions")
# Batch plus a `values` field — presumably value/return targets for critic
# training; not constructed in this file, verify against callers.
ValueBatch = namedtuple("ValueBatch", "trajectories conditions values")
class StitchedSequenceDataset(torch.utils.data.Dataset):
    """
    Dataset to efficiently load and sample trajectories. Stitches episodes
    together in the time dimension to avoid excessive zero padding. Episode
    ID's are used to index unique trajectories.

    Returns a dictionary with values of shape: [sum_e(T_e), *D] where T_e is
    traj length of episode e and D is (tuple of) dimension of observation,
    action, images, etc.

    Example:
        Observations: [----------traj 1----------][---------traj 2----------] ... [---------traj N----------]
        Episode IDs:  [---------- 1 ----------][---------- 2 ---------] ... [---------- N ---------]
    """

    def __init__(
        self,
        dataset_path,
        horizon_steps=64,
        cond_steps=1,
        max_n_episodes=10000,
        use_img=False,
        device="cpu",
    ):
        """
        Args:
            dataset_path: path to a ``.npz`` or pickle file with keys
                ``observations``, ``actions``, ``traj_length`` and optionally
                ``images``, each indexed by episode along the first axis.
            horizon_steps: length of each sampled sub-trajectory.
            cond_steps: number of conditioning (history) observation steps.
            max_n_episodes: cap on the number of episodes loaded.
            use_img: whether per-step images are also loaded.
            device: where stitched data lives; on "cpu" the raw numpy arrays
                are kept, otherwise data is converted to tensors and moved.
        """
        self.horizon_steps = horizon_steps
        self.cond_steps = cond_steps
        self.device = device

        # Load dataset to device specified
        if dataset_path.endswith(".npz"):
            dataset = np.load(dataset_path, allow_pickle=True)
        else:
            # NOTE(review): pickle.load on untrusted files can execute
            # arbitrary code — only load datasets from trusted sources.
            with open(dataset_path, "rb") as f:
                dataset = pickle.load(f)
        num_episodes = dataset["observations"].shape[0]

        # Get the sum total of the valid trajectories' lengths
        traj_lengths = dataset["traj_length"]
        sum_of_path_lengths = np.sum(traj_lengths)
        self.sum_of_path_lengths = sum_of_path_lengths

        fields = StitchedBuffer(sum_of_path_lengths, device)
        for i in tqdm(
            range(min(max_n_episodes, num_episodes)), desc="Loading trajectories"
        ):
            traj_length = traj_lengths[i]
            # Trim per-episode zero padding down to the true trajectory length.
            episode = {
                "observations": dataset["observations"][i][:traj_length],
                "actions": dataset["actions"][i][:traj_length],
                "episode_ids": i * np.ones(traj_length),
            }
            if use_img:
                episode["images"] = dataset["images"][i][:traj_length]
            for key, val in episode.items():
                if device == "cpu":
                    # Keep raw numpy arrays on CPU.
                    episode[key] = val
                else:
                    # if None array, save as empty tensor
                    if np.all(np.equal(episode[key], None)):
                        episode[key] = torch.empty(episode[key].shape).to(device)
                    else:
                        if key == "images":
                            # uint8 keeps image memory footprint small.
                            episode[key] = torch.tensor(val, dtype=torch.uint8).to(
                                device
                            )
                            # (, H, W, C) -> (, C, H, W)
                            episode[key] = episode[key].permute(0, 3, 1, 2)
                        else:
                            episode[key] = torch.tensor(val, dtype=torch.float32).to(
                                device
                            )
            fields.add_path(episode)
        fields.finalize()

        self.indices = self.make_indices(traj_lengths, horizon_steps)
        self.obs_dim = fields.observations.shape[-1]
        self.action_dim = fields.actions.shape[-1]
        self.fields = fields
        self.n_episodes = fields.n_episodes
        self.path_lengths = fields.path_lengths
        self.traj_lengths = traj_lengths
        self.use_img = use_img
        log.info(fields)

    def make_indices(self, traj_lengths, horizon_steps):
        """
        makes indices for sampling from dataset;
        each index maps to a datapoint — a valid start step such that a full
        horizon fits inside a single episode (no stitching across episodes)
        """
        indices = []
        cur_traj_index = 0
        for traj_length in traj_lengths:
            # Last valid start leaves exactly horizon_steps steps in the episode.
            max_start = cur_traj_index + traj_length - horizon_steps + 1
            indices += list(range(cur_traj_index, max_start))
            cur_traj_index += traj_length
        return indices

    def set_train_val_split(self, train_split):
        """
        Randomly split the current indices into train/val; keeps the train
        indices on this dataset and returns the validation indices.

        Bug fix: validation indices are now the actual index values not
        sampled for training. The previous implementation filtered positions
        ``range(len(self.indices))`` against the sampled *values*, mixing two
        different index spaces (stitched offsets vs. list positions), so the
        returned validation set could overlap the training set and contain
        invalid start offsets.
        """
        num_train = int(len(self.indices) * train_split)
        train_indices = random.sample(self.indices, num_train)
        train_set = set(train_indices)  # O(1) membership for the filter below
        val_indices = [i for i in self.indices if i not in train_set]
        self.indices = train_indices
        return val_indices

    def set_indices(self, indices):
        """Override the sampling indices (e.g. to restore a saved split)."""
        self.indices = indices

    def get_conditions(self, observations, images=None):
        """
        condition on current observation for planning. Take into account the
        number of conditioning steps.
        """
        if images is not None:
            return {
                1 - self.cond_steps: {"state": observations[0], "rgb": images[0]}
            }  # TODO: allow obs history, -1, -2, ...
        else:
            return {1 - self.cond_steps: observations[0]}

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx, eps=1e-4):
        # `eps` is unused here; kept so subclasses may share the signature.
        raise NotImplementedError("Get item defined in subclass.")
class StitchedActionSequenceDataset(StitchedSequenceDataset):
    """Only use action trajectory, and then obs_cond for current observation"""

    def __getitem__(self, idx):
        # Map the flat sample index to a [first, last) window in the
        # stitched buffers.
        first = self.indices[idx]
        last = first + self.horizon_steps
        obs_window = self.fields.observations[first:last]
        act_window = self.fields.actions[first:last]
        img_window = self.fields.images[first:last] if self.use_img else None
        # The model target is the action sequence; observations only enter
        # through the conditioning dict.
        return Batch(act_window, self.get_conditions(obs_window, img_window))