fix observation history indexing in the dataset

This commit is contained in:
allenzren 2024-09-17 12:43:15 -04:00
parent 1aaa6c2302
commit ef5b14f820

View File

@ -91,7 +91,7 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
actions = self.actions[start:end] actions = self.actions[start:end]
states = torch.stack( states = torch.stack(
[ [
states[min(num_before_start - t, 0)] states[max(num_before_start - t, 0)]
for t in reversed(range(self.cond_steps)) for t in reversed(range(self.cond_steps))
] ]
) # more recent is at the end ) # more recent is at the end
@ -100,7 +100,7 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
images = self.images[(start - num_before_start) : end] images = self.images[(start - num_before_start) : end]
images = torch.stack( images = torch.stack(
[ [
images[min(num_before_start - t, 0)] images[max(num_before_start - t, 0)]
for t in reversed(range(self.img_cond_steps)) for t in reversed(range(self.img_cond_steps))
] ]
) )