"""
|
|
Pre-training data loader. Modified from https://github.com/jannerm/diffuser/blob/main/diffuser/datasets/sequence.py
|
|
|
|
TODO: implement history observation
|
|
|
|
No normalization is applied here --- we always normalize the data when pre-processing it with a different script, and the normalization info is also used in RL fine-tuning.
|
|
|
|
"""
|
|
|
|

import logging
import pickle
import random
from collections import namedtuple

import numpy as np
import torch
from tqdm import tqdm

from .buffer import StitchedBuffer

log = logging.getLogger(__name__)

Batch = namedtuple("Batch", "trajectories conditions")
ValueBatch = namedtuple("ValueBatch", "trajectories conditions values")
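
# For a single sample from the action-sequence dataset below, a Batch holds
# (illustrative sketch of the shapes, assuming cond_steps == 1):
#   Batch(trajectories=<(horizon_steps, action_dim) tensor>,
#         conditions={0: <current observation>})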


class StitchedSequenceDataset(torch.utils.data.Dataset):
    """
    Dataset to efficiently load and sample trajectories. Stitches episodes
    together along the time dimension to avoid excessive zero padding.
    Episode IDs are used to index unique trajectories.

    The stitched buffer holds a dictionary whose values have shape
    [sum_e(T_e), *D], where T_e is the trajectory length of episode e and D is
    the (tuple of) dimension(s) of the observations, actions, images, etc.

    Example:
        Observations: [----------traj 1----------][---------traj 2----------] ... [---------traj N----------]
        Episode IDs:  [---------- 1 ----------][---------- 2 ---------] ... [---------- N ---------]
    """

    def __init__(
        self,
        dataset_path,
        horizon_steps=64,
        cond_steps=1,
        max_n_episodes=10000,
        use_img=False,
        device="cpu",
    ):
        self.horizon_steps = horizon_steps
        self.cond_steps = cond_steps
        self.device = device

        # Load the dataset onto the specified device
        if dataset_path.endswith(".npz"):
            dataset = np.load(dataset_path, allow_pickle=True)
        else:
            with open(dataset_path, "rb") as f:
                dataset = pickle.load(f)
        num_episodes = dataset["observations"].shape[0]

        # Sum of the valid trajectories' lengths
        traj_lengths = dataset["traj_length"]
        sum_of_path_lengths = np.sum(traj_lengths)
        self.sum_of_path_lengths = sum_of_path_lengths

        fields = StitchedBuffer(sum_of_path_lengths, device)
        for i in tqdm(
            range(min(max_n_episodes, num_episodes)), desc="Loading trajectories"
        ):
            traj_length = traj_lengths[i]
            episode = {
                "observations": dataset["observations"][i][:traj_length],
                "actions": dataset["actions"][i][:traj_length],
                "episode_ids": i * np.ones(traj_length),
            }
            if use_img:
                episode["images"] = dataset["images"][i][:traj_length]
            for key, val in episode.items():
                if device == "cpu":
                    episode[key] = val
                elif np.all(np.equal(val, None)):
                    # All-None array: store an empty tensor as a placeholder
                    episode[key] = torch.empty(val.shape).to(device)
                elif key == "images":
                    # Keep images as uint8; (T, H, W, C) -> (T, C, H, W)
                    episode[key] = torch.tensor(val, dtype=torch.uint8).to(device)
                    episode[key] = episode[key].permute(0, 3, 1, 2)
                else:
                    episode[key] = torch.tensor(val, dtype=torch.float32).to(device)
            fields.add_path(episode)
        fields.finalize()

        self.indices = self.make_indices(traj_lengths, horizon_steps)
        self.obs_dim = fields.observations.shape[-1]
        self.action_dim = fields.actions.shape[-1]
        self.fields = fields
        self.n_episodes = fields.n_episodes
        self.path_lengths = fields.path_lengths
        self.traj_lengths = traj_lengths
        self.use_img = use_img
        log.info(fields)

    def make_indices(self, traj_lengths, horizon_steps):
        """
        Make indices for sampling from the dataset: each index maps to the
        first timestep of a horizon-length window that lies entirely within
        a single trajectory.
        """
        indices = []
        cur_traj_index = 0
        for traj_length in traj_lengths:
            # One past the last valid window start within this trajectory
            max_start = cur_traj_index + traj_length - horizon_steps + 1
            indices += list(range(cur_traj_index, max_start))
            cur_traj_index += traj_length
        return indices
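
    # Illustrative walk-through (not executed): with traj_lengths = [5, 3] and
    # horizon_steps = 3, trajectory 1 occupies stitched rows 0-4 and trajectory 2
    # occupies rows 5-7, so make_indices returns [0, 1, 2, 5], i.e. every window
    # start whose full horizon stays inside one episode.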

    def set_train_val_split(self, train_split):
        num_train = int(len(self.indices) * train_split)
        train_indices = random.sample(self.indices, num_train)
        # The validation set is the complement of the sampled start indices;
        # compare against the index values themselves, not their positions
        # in the list
        train_index_set = set(train_indices)
        val_indices = [i for i in self.indices if i not in train_index_set]
        self.indices = train_indices
        return val_indices

    def set_indices(self, indices):
        self.indices = indices

    def get_conditions(self, observations, images=None):
        """
        Condition on the current observation for planning, taking the number
        of conditioning steps into account. The key 1 - cond_steps is the
        offset of the earliest conditioning step (0 when cond_steps == 1).
        """
        if images is not None:
            return {
                1 - self.cond_steps: {"state": observations[0], "rgb": images[0]}
            }  # TODO: allow obs history, -1, -2, ...
        return {1 - self.cond_steps: observations[0]}

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        raise NotImplementedError("__getitem__ is defined in subclasses.")


class StitchedActionSequenceDataset(StitchedSequenceDataset):
    """Use only the action trajectory, conditioning on the current observation."""

    def __getitem__(self, idx):
        start = self.indices[idx]
        end = start + self.horizon_steps
        observations = self.fields.observations[start:end]
        actions = self.fields.actions[start:end]
        images = None
        if self.use_img:
            images = self.fields.images[start:end]
        conditions = self.get_conditions(observations, images)
        batch = Batch(actions, conditions)
        return batch
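

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it assumes a
    # pre-processed file at the hypothetical path "data/train.npz" containing
    # "observations", "actions", and "traj_length" arrays as described above.
    # PyTorch's default collate handles the Batch namedtuple and the nested
    # conditions dict.
    dataset = StitchedActionSequenceDataset(
        "data/train.npz", horizon_steps=16, cond_steps=1, device="cpu"
    )
    val_indices = dataset.set_train_val_split(train_split=0.9)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    batch = next(iter(loader))
    print(batch.trajectories.shape)  # e.g. (32, 16, action_dim)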