fix observation history indexing in the dataset

This commit is contained in:
allenzren 2024-09-17 12:43:15 -04:00
parent 1aaa6c2302
commit ef5b14f820

View File

@ -91,7 +91,7 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
actions = self.actions[start:end]
states = torch.stack(
[
states[min(num_before_start - t, 0)]
states[max(num_before_start - t, 0)]
for t in reversed(range(self.cond_steps))
]
) # more recent is at the end
@ -100,7 +100,7 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
images = self.images[(start - num_before_start) : end]
images = torch.stack(
[
images[min(num_before_start - t, 0)]
images[max(num_before_start - t, 0)]
for t in reversed(range(self.img_cond_steps))
]
)